Skip to content

Commit

Permalink
Handle insert value of multi type (#1560)
Browse files Browse the repository at this point in the history
* Handle insert value of multi type

* Fix multi agg

* now with separated

* fix insertion index math

* Allow pointer in double addTo

* fix

* fix ins typetree

* fix erasure
  • Loading branch information
wsmoses authored Nov 27, 2023
1 parent 6732d3d commit 1135f76
Show file tree
Hide file tree
Showing 10 changed files with 361 additions and 133 deletions.
201 changes: 148 additions & 53 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -1084,8 +1084,14 @@ class AdjointGenerator

auto dt = vd[{-1}];
for (size_t i = start; i < size; ++i) {
auto nex = vd[{(int)i}];
if ((nex == BaseType::Anything && dt.isFloat()) ||
(dt == BaseType::Anything && nex.isFloat())) {
nextStart = i;
break;
}
bool Legal = true;
dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal);
dt.checkedOrIn(nex, /*PointerIntSame*/ true, Legal);
if (!Legal) {
nextStart = i;
break;
Expand Down Expand Up @@ -1199,7 +1205,8 @@ class AdjointGenerator
Builder2, align, start, size, isVolatile, ordering, syncScope,
mask, prevNoAlias, prevScopes);
((DiffeGradientUtils *)gutils)
->addToDiffe(orig_val, diff, Builder2, FT, start, size, mask);
->addToDiffe(orig_val, diff, Builder2, FT, start, size, {},
mask);
}
break;
}
Expand Down Expand Up @@ -1909,72 +1916,160 @@ class AdjointGenerator

if (!gutils->isConstantValue(orig_inserted)) {
auto TT = TR.query(orig_inserted);
auto it = TT[{-1}];
bool Legal = true;
for (size_t i = 0; i < size0; ++i) {
bool LegalOr = true;
it.checkedOrIn(TT[{(int)i}], /*pointerIntSame*/ true, LegalOr);
Legal &= LegalOr;
}
Type *flt = it.isFloat();
if (!it.isKnown() || !Legal) {
bool found = false;

if (looseTypeAnalysis && !Legal) {
if (orig_inserted->getType()->isFPOrFPVectorTy()) {
flt = orig_inserted->getType()->getScalarType();
found = true;
} else if (orig_inserted->getType()->isIntOrIntVectorTy() ||
orig_inserted->getType()->isPointerTy()) {
flt = nullptr;
found = true;

unsigned start = 0;
Value *dindex = nullptr;

while (1) {
unsigned nextStart = size0;

auto dt = TT[{-1}];
for (size_t i = start; i < size0; ++i) {
auto nex = TT[{(int)i}];
if ((nex == BaseType::Anything && dt.isFloat()) ||
(dt == BaseType::Anything && nex.isFloat())) {
nextStart = i;
break;
}
bool Legal = true;
dt.checkedOrIn(nex, /*PointerIntSame*/ true, Legal);
if (!Legal) {
nextStart = i;
break;
}
}
if (!found) {
std::string str;
raw_string_ostream ss(str);
ss << "Cannot deduce type of insertvalue " << IVI
<< " size: " << size0 << " TT: " << TT.str();
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&IVI), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&Builder2));
} else {
EmitFailure("CannotDeduceType", IVI.getDebugLoc(), &IVI,
ss.str());
Type *flt = dt.isFloat();
if (!dt.isKnown()) {
bool found = false;
if (looseTypeAnalysis) {
if (orig_inserted->getType()->isFPOrFPVectorTy()) {
flt = orig_inserted->getType()->getScalarType();
found = true;
} else if (orig_inserted->getType()->isIntOrIntVectorTy() ||
orig_inserted->getType()->isPointerTy()) {
flt = nullptr;
found = true;
}
}
if (!found) {
std::string str;
raw_string_ostream ss(str);
ss << "Cannot deduce type of insertvalue ins " << IVI
<< " size: " << size0 << " TT: " << TT.str();
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&IVI), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&Builder2));
} else {
EmitFailure("CannotDeduceType", IVI.getDebugLoc(), &IVI,
ss.str());
}
}
}
}
if (flt) {
auto rule = [&](Value *prediff) {
return Builder2.CreateExtractValue(prediff, IVI.getIndices());
};
auto prediff = diffe(&IVI, Builder2);
auto dindex =
applyChainRule(orig_inserted->getType(), Builder2, rule, prediff);
addToDiffe(orig_inserted, dindex, Builder2, flt);

if (flt) {
if (!dindex) {
auto rule = [&](Value *prediff) {
return Builder2.CreateExtractValue(prediff, IVI.getIndices());
};
auto prediff = diffe(&IVI, Builder2);
dindex = applyChainRule(orig_inserted->getType(), Builder2, rule,
prediff);
}

auto TT = TR.query(orig_inserted);

((DiffeGradientUtils *)gutils)
->addToDiffe(orig_inserted, dindex, Builder2, flt, start,
nextStart - start);
}
if (nextStart == size0)
break;
start = nextStart;
}
}

size_t size1 = 1;
if (orig_agg->getType()->isSized() &&
(orig_agg->getType()->isIntOrIntVectorTy() ||
orig_agg->getType()->isFPOrFPVectorTy()))
if (orig_agg->getType()->isSized())
size1 =
(gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
orig_agg->getType()) +
7) /
8;

if (!gutils->isConstantValue(orig_agg)) {
auto rule = [&](Value *prediff) {
return Builder2.CreateInsertValue(
prediff, Constant::getNullValue(orig_inserted->getType()),
IVI.getIndices());
};
auto prediff = diffe(&IVI, Builder2);
auto dindex =
applyChainRule(orig_agg->getType(), Builder2, rule, prediff);
addToDiffe(orig_agg, dindex, Builder2, TR.addingType(size1, orig_agg));

auto TT = TR.query(orig_agg);

unsigned start = 0;

Value *dindex = nullptr;

while (1) {
unsigned nextStart = size1;

auto dt = TT[{-1}];
for (size_t i = start; i < size1; ++i) {
auto nex = TT[{(int)i}];
if ((nex == BaseType::Anything && dt.isFloat()) ||
(dt == BaseType::Anything && nex.isFloat())) {
nextStart = i;
break;
}
bool Legal = true;
dt.checkedOrIn(nex, /*PointerIntSame*/ true, Legal);
if (!Legal) {
nextStart = i;
break;
}
}
Type *flt = dt.isFloat();
if (!dt.isKnown()) {
bool found = false;
if (looseTypeAnalysis) {
if (orig_agg->getType()->isFPOrFPVectorTy()) {
flt = orig_agg->getType()->getScalarType();
found = true;
} else if (orig_agg->getType()->isIntOrIntVectorTy() ||
orig_agg->getType()->isPointerTy()) {
flt = nullptr;
found = true;
}
}
if (!found) {
std::string str;
raw_string_ostream ss(str);
ss << "Cannot deduce type of insertvalue agg " << IVI
<< " start: " << start << " size: " << size1
<< " TT: " << TT.str();
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&IVI), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&Builder2));
} else {
EmitFailure("CannotDeduceType", IVI.getDebugLoc(), &IVI,
ss.str());
}
}
}

if (flt) {
if (!dindex) {
auto rule = [&](Value *prediff) {
return Builder2.CreateInsertValue(
prediff, Constant::getNullValue(orig_inserted->getType()),
IVI.getIndices());
};
auto prediff = diffe(&IVI, Builder2);
dindex =
applyChainRule(orig_agg->getType(), Builder2, rule, prediff);
}
((DiffeGradientUtils *)gutils)
->addToDiffe(orig_agg, dindex, Builder2, flt, start,
nextStart - start);
}
if (nextStart == size1)
break;
start = nextStart;
}
}

setDiffe(&IVI,
Expand Down
Loading

0 comments on commit 1135f76

Please sign in to comment.