Skip to content

Commit

Permalink
Merge branch 'CLAM-1461_4' into 'main'
Browse files Browse the repository at this point in the history
Fix structure address computation and pointer bounds checking.

See merge request clamav/clamav-bytecode-compiler!14
  • Loading branch information
ragusaa committed May 24, 2021
2 parents c2d3c80 + 8dd1f21 commit 515291a
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,8 @@ class ClamBCPrepareGEPsForWriter : public ModulePass
}

if (StructType *pst = llvm::dyn_cast<StructType>(pt)) {

for (size_t i = 0; i < pst->getNumElements(); i++) {
size += getTypeSize(pst->getTypeAtIndex(i));
}

if (size) {
return size;
}
const StructLayout * psl = pMod->getDataLayout().getStructLayout(pst);
return psl->getSizeInBits();
}

assert(0 && "Size has not been computed");
Expand All @@ -138,11 +132,11 @@ class ClamBCPrepareGEPsForWriter : public ModulePass
if (StructType * pst = llvm::dyn_cast<StructType>(pt)){
assert((idx <= pst->getNumElements()) && "Idx too high");

for (uint64_t i = 0; i < idx; i++) {
Type *pt = pst->getElementType(i);
int64_t size = getTypeSizeInBytes(pt);
cnt += size;
}
const StructLayout * psl = pMod->getDataLayout().getStructLayout(pst);
assert (psl && "Could not get layout");

cnt = psl->getElementOffsetInBits(idx)/8;

} else if (ArrayType * pat = llvm::dyn_cast<ArrayType>(pt)){
assert((idx <= pat->getNumElements()) && "Idx too high");
cnt = idx * getTypeSizeInBytes(pat->getElementType());
Expand Down Expand Up @@ -208,7 +202,6 @@ class ClamBCPrepareGEPsForWriter : public ModulePass
if ( ConstantInt * ciIdx = llvm::dyn_cast<ConstantInt>(vIdx)){

uint64_t val = computeOffsetInBytes(currType, ciIdx);
//ConstantInt * ciAddend = ConstantInt::get(ciIdx->getType(), val);
ciAddend = ConstantInt::get(ciIdx->getType(), val);

Type * tmp = findTypeAtIndex(currType, ciIdx);
Expand Down Expand Up @@ -322,14 +315,22 @@ class ClamBCPrepareGEPsForWriter : public ModulePass
pgepi->eraseFromParent();
}

virtual Value* stripBitCasts(Value * pInst){
if (BitCastInst * pbci = llvm::dyn_cast<BitCastInst>(pInst)){
return stripBitCasts(pbci->getOperand(0));
}

return pInst;
}

virtual void processGEPI(GetElementPtrInst * pgepi){

Type * pdst = Type::getInt8Ty(pMod->getContext());

Value * vPtr = pgepi->getPointerOperand();
if (BitCastInst * pbci = llvm::dyn_cast<BitCastInst>(vPtr)){
vPtr = GetUnderlyingObject(pbci, pMod->getDataLayout());
vPtr = stripBitCasts(pbci);

Type * ptrType = vPtr->getType()->getPointerElementType();

if (ArrayType * pat = llvm::dyn_cast<ArrayType>(ptrType)){
Expand All @@ -338,22 +339,18 @@ class ClamBCPrepareGEPsForWriter : public ModulePass
assert (0 && "ClamBCLowering did not do it's job");
}


Type * gepiDstType = pbci->getType()->getPointerElementType();
if (StructType * pst = llvm::dyn_cast<StructType>(gepiDstType)){
processGEPI(pgepi, pbci, vPtr, pst);
} else if (ArrayType * pat = llvm::dyn_cast<ArrayType>(gepiDstType)){
processGEPI(pgepi, pbci, vPtr, pat);
} else {
DEBUGERR << *gepiDstType << "<END>\n";
}

} else {
assert (0 && "FIGURE OUT IF I NEED TO DO ANYTHING HERE?");
}
}


virtual void convertArrayStructGEPIsToI8(Function * pFunc){
std::vector<GetElementPtrInst * > gepis;
for (auto i = pFunc->begin(), e = pFunc->end(); i != e; i++){
Expand Down
36 changes: 31 additions & 5 deletions libclambcc/ClamBCRemoveUndefs/ClamBCRemoveUndefs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,13 @@ class ClamBCRemoveUndefs : public ModulePass
{
Function *pFunc = BB->getParent();

//BasicBlock abrt = std::find(pFunc, aborts.begin(), aborts.end());
auto iter = aborts.find(pFunc);
if (aborts.end() != iter) {
return iter->second;
}

FunctionType *abrtTy = FunctionType::get(
Type::getVoidTy(BB->getContext()), false);
//args.push_back(Type::getInt32Ty(BB->getContext()));
FunctionType *rterrTy = FunctionType::get(
Type::getInt32Ty(BB->getContext()),
{Type::getInt32Ty(BB->getContext())}, false);
Expand All @@ -63,8 +61,6 @@ class ClamBCRemoveUndefs : public ModulePass
Constant *func_rterr =
BB->getParent()->getParent()->getOrInsertFunction("bytecode_rt_error", rterrTy);
BasicBlock *abort = BasicBlock::Create(BB->getContext(), "rterr.trig", BB->getParent());
// PHINode * PN = PHINode::Create(Type::getInt32Ty(BB->getContext()), 0, "ClamBCRTChecks_abort",
// abort);
Constant *PN = ConstantInt::get(Type::getInt32Ty(BB->getContext()), 99);
if (MDDbgKind) {
CallInst *RtErrCall = CallInst::Create(func_rterr, PN, "", abort);
Expand Down Expand Up @@ -124,6 +120,34 @@ class ClamBCRemoveUndefs : public ModulePass

delLst.push_back(term);
bChanged = true;

}

virtual bool isSamePointer(Value * ptr1, Value * ptr2, std::set<llvm::Value *> &visited) {

if (visited.end() != std::find(visited.begin(), visited.end(), ptr1)) {
return false;
}
visited.insert(ptr1);

if (ptr1 == ptr2){
return true;
}

if (User * pu = llvm::dyn_cast<User>(ptr1)){

for (size_t i = 0; i < pu->getNumOperands(); i++){
if (isSamePointer(pu->getOperand(i), ptr2, visited)){
return true;
}
}
}
return false;
}

virtual bool isSamePointer(Value * ptr1, Value * ptr2) {
std::set<llvm::Value *> visited;
return isSamePointer(ptr1, ptr2, visited);
}

virtual void insertChecks(Value *ptr, Value *size)
Expand All @@ -135,7 +159,9 @@ class ClamBCRemoveUndefs : public ModulePass

for (auto i : insts) {
if (GetElementPtrInst *pgepi = llvm::dyn_cast<GetElementPtrInst>(i)) {
insertChecks(pgepi, size);
if (isSamePointer(pgepi->getPointerOperand(), ptr)){
insertChecks(pgepi, size);
}
}
}
}
Expand Down
28 changes: 2 additions & 26 deletions libclambcc/Common/ClamBCRegAlloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,22 +60,14 @@ void ClamBCRegAlloc::handlePHI(PHINode *PN)
unsigned MDDbgKind = PN->getContext().getMDKindID("dbg");
if (MDDbgKind) {
if (MDNode *Dbg = PN->getMetadata(MDDbgKind)) {
#if 0
builder.SetCurrentDebugLocation(Dbg);
#else
DebugLoc dl(Dbg);
builder.SetCurrentDebugLocation(dl);
#endif
}
}
for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
BasicBlock *BB = PN->getIncomingBlock(i);
Value *V = PN->getIncomingValue(i);
#if 0
builder.SetInsertPoint(BB, BB->getTerminator());
#else
builder.SetInsertPoint(BB->getTerminator());
#endif
Instruction *I = builder.CreateStore(V, AI);
builder.SetInstDebugLocation(I);
}
Expand Down Expand Up @@ -143,9 +135,10 @@ bool ClamBCRegAlloc::runOnFunction(Function &F)
const PointerType *SPTy, *DPTy;
while ((SPTy = dyn_cast<PointerType>(SrcTy))) {
DPTy = dyn_cast<PointerType>(DstTy);
if (!DPTy)
if (!DPTy) {
ClamBCStop("Cast from pointer to non-pointer element",
BCI);
}
SrcTy = SPTy->getElementType();
DstTy = DPTy->getElementType();
}
Expand All @@ -166,13 +159,6 @@ bool ClamBCRegAlloc::runOnFunction(Function &F)
ValueMap[II] = getValueID(II->getOperand(0));
continue;
}
#if 0
if (isa<PtrToIntInst>(BC)) {
// sub ptrtoint, ptrtoint is supported
SkipMap.insert(II);
continue;
}
#endif
}
if (II->hasOneUse()) {
// single-use store to alloca -> store directly to alloca
Expand Down Expand Up @@ -265,18 +251,8 @@ unsigned ClamBCRegAlloc::buildReverseMap(std::vector<const Value *> &reverseMap)

void ClamBCRegAlloc::getAnalysisUsage(AnalysisUsage &AU) const
{
//AU.addRequired<LiveValues>();
AU.addRequired<DominatorTreeWrapperPass>();

#if 0
// We promise not to introduce anything that is unsafe.
// If the verifier accepted the bytecode so far, we don't break it.
// This is needed because we can't rerun the verifier: it can only
// analyze bytecode in SSA form, and we intentionally break SSA form here
// (we eliminate PHIs).
AU.addPreservedID(ClamBCVerifierID);
#endif

// Preserve the CFG, we only eliminate PHIs, and introduce some
// loads/stores.
AU.setPreservesCFG();
Expand Down

0 comments on commit 515291a

Please sign in to comment.