From 2a55340555f4a1924f8902bc758829fbb4d32014 Mon Sep 17 00:00:00 2001 From: Andy Ragusa Date: Fri, 4 Jun 2021 14:23:30 -0700 Subject: [PATCH] Handle PHINodes with all constants in ClamBCLogicalCompiler. Previously, ClamBCLogicalCompiler expected at least one possible value of a PHINode to be variable, and not constant while processing PHINodes. This has been fixed by finding all possible values that could be TRUE and ORing them together. --- .../ClamBCLogicalCompiler.cpp | 328 ++++++++++-------- 1 file changed, 191 insertions(+), 137 deletions(-) diff --git a/libclambcc/ClamBCLogicalCompiler/ClamBCLogicalCompiler.cpp b/libclambcc/ClamBCLogicalCompiler/ClamBCLogicalCompiler.cpp index ff4534fa63..56d748f0f3 100644 --- a/libclambcc/ClamBCLogicalCompiler/ClamBCLogicalCompiler.cpp +++ b/libclambcc/ClamBCLogicalCompiler/ClamBCLogicalCompiler.cpp @@ -61,6 +61,7 @@ using namespace llvm; namespace { + class ClamBCLogicalCompiler : public ModulePass { public: @@ -88,6 +89,7 @@ class ClamBCLogicalCompiler : public ModulePass bool validateVirusName(const std::string &name, Module &M, bool suffix = false); bool compileVirusNames(Module &M, unsigned kind); }; + char ClamBCLogicalCompiler::ID = 0; RegisterPass X("clambc-lcompiler", "ClamAV Logical Compiler"); @@ -686,177 +688,224 @@ class LogicalCompiler return true; } - LogicalNode *generateLogicalPHI(PHINode *phi, LogicalNode *ln, size_t idx, std::vector &nodes) - { - Module *pMod = phi->getParent()->getParent()->getParent(); - ConstantInt *ciFalse = ConstantInt::getFalse(pMod->getContext()); - ConstantInt *ciTrue = ConstantInt::getTrue(pMod->getContext()); - BasicBlock *incoming = phi->getIncomingBlock(idx); - LogicalNode *newNode = nullptr; + class LogicalPHIHelper { + public: - bool matchedAnything = false; - for (size_t i = 0; i < phi->getNumIncomingValues(); i++) { - if (i == idx) { - continue; + LogicalPHIHelper (BranchInst * bi, bool isTrue) { + pBranchInst = bi; + pBasicBlock = bi->getParent(); + this->isTrue = isTrue; + } + + LogicalPHIHelper(LogicalPHIHelper * lph){ + this->pBasicBlock = lph->pBasicBlock; + this->pBranchInst = lph->pBranchInst; + this->isTrue = lph->isTrue; + } + + virtual ~LogicalPHIHelper(){} + + BranchInst * getBranchInst() { return pBranchInst; } + + BasicBlock * getBasicBlock() { return pBasicBlock; } + + bool getIsTrue() { + return isTrue; + } + + Value * getCondition(){ + if (pBranchInst->isConditional()){ + return pBranchInst->getCondition(); } + return nullptr; + } - BasicBlock *bb = phi->getIncomingBlock(i); - Value *v = bb->getTerminator(); - BranchInst *bi = llvm::cast(v); - Value *incomingValue = phi->getIncomingValue(i); - - bool foundMatch = false; - if (bi->getSuccessor(0) == incoming) { - //pair these blocks - foundMatch = true; - Value *v = incomingValue; - if (bi->isConditional()) { - v = bi->getCondition(); - } - LogicalMap::iterator iter = Map.find(v); - if (Map.end() == iter) { - assert(0 && "HOW DID THIS HAPPEN?"); - } - LogicalNode *tmp = nodes[i]; + protected: + BasicBlock * pBasicBlock; - if (ciFalse == incomingValue) { - newNode = LogicalNode::getAnd(ln, tmp); - } else if (ciTrue == incomingValue) { - newNode = LogicalNode::getOr(ln, tmp); - } else { - assert(0 && "HANDLE THIS CASE"); - } - } else if (bi->isConditional()) { - if (bi->getSuccessor(1) == incoming) { - //pairt these blocks + BranchInst * pBranchInst; - Value *v = incomingValue; - if (bi->isConditional()) { - v = bi->getCondition(); - } + bool isTrue; - LogicalMap::iterator iter = Map.find(v); - if (Map.end() == iter) { - assert(0 && "HOW DID THIS HAPPEN?"); - } - //LogicalNode * tmp = iter->second; - LogicalNode *tmp = nodes[i]; + }; - if (ciFalse == incomingValue) { - newNode = LogicalNode::getAnd(ln, tmp); - //return newNode; - } else if (ciTrue == incomingValue) { - tmp = LogicalNode::getNot(tmp); - newNode = LogicalNode::getOr(ln, tmp); - } else { - assert(0 && "HANDLE THIS CASE"); - } - foundMatch = true; - } + /*Generate all paths from the 'curr' to 'end' and store them in routes.*/ + void populateRoutes(BasicBlock * curr, BasicBlock * end, std::vector> & routes, size_t idx){ + + if (curr == end){ + return; + } + + for (size_t i = 0; i < routes[idx].size(); i++){ + if (routes[idx][i]->getBranchInst() == curr->getTerminator()){ + return ; } - if (foundMatch) { - matchedAnything = true; - LogicalNode *tmp = generateLogicalPHI(phi, newNode, i, nodes); - if (nullptr != tmp) { - return tmp; + } + + if (BranchInst * bi = llvm::dyn_cast(curr->getTerminator())){ + if (bi->isConditional()){ + //copy the route, so that there are separate paths for the true + //and false condition. + std::vector route; + for (size_t i = 0; i < routes[idx].size(); i++){ + route.push_back(new LogicalPHIHelper(routes[idx][i])); } + routes.push_back(route); + size_t falseIdx = routes.size()-1; + + routes[idx].push_back(new LogicalPHIHelper(bi, true)); + routes[falseIdx].push_back(new LogicalPHIHelper(bi, false)); + + populateRoutes(bi->getSuccessor(0), end, routes, idx); + populateRoutes(bi->getSuccessor(1), end, routes, falseIdx); + + } else { + routes[idx].push_back(new LogicalPHIHelper(bi, true)); + populateRoutes(bi->getSuccessor(0), end, routes, idx); } - } - if (not matchedAnything) { - return ln; } - return nullptr; } - void generateLogicalPHI(PHINode *phi, std::vector &nodes) - { - for (size_t i = 0; i < phi->getNumIncomingValues(); i++) { - llvm::Value *pIncoming = phi->getIncomingValue(i); - if (not isa(pIncoming)) { - LogicalNode *newNode = generateLogicalPHI(phi, nodes[i], i, nodes); - Map[phi] = newNode; - return; + + /* Find all routes from the entry BasicBlock that end with 'pBasicBlock' */ + std::vector findRoute(BasicBlock * pBasicBlock, std::vector> & routes){ + std::vector ret; + for (size_t i = 0; i < routes.size(); i++){ + size_t lastIdx = routes[i].size()-1; + if (routes[i][lastIdx]->getBasicBlock() == pBasicBlock){ + ret.push_back(i); } } + return ret; } - void saveNodes(PHINode *phiNode, std::vector &nodes) - { + LogicalNode * getLogicalNode(std::vector & route){ + LogicalNode * ret = nullptr; - llvm::Module *pMod = phiNode->getParent()->getParent()->getParent(); - - for (size_t i = 0; i < phiNode->getNumIncomingValues(); i++) { - Value *pVal = phiNode->getIncomingValue(i); - BasicBlock *pIncomingBlock = phiNode->getIncomingBlock(i); - Instruction *pTerminator = pIncomingBlock->getTerminator(); - assert(llvm::isa(pTerminator) && "HOW IS THIS POSSIBLE?"); - BranchInst *pTerminatorBranch = llvm::cast(pTerminator); - - if (pVal == ConstantInt::getTrue(pMod->getContext())) { - Value *pCond = pTerminatorBranch->getCondition(); - - LogicalMap::iterator condIter = Map.find(pCond); - if (Map.end() == condIter) { - /*This should be impossible, but insert meaningful error message here.*/ - assert(0 && "HOW DID THIS HAPPEN?"); + for (size_t i = 0; i < route.size(); i++){ + Value * vCond = route[i]->getCondition(); + if (vCond){ + LogicalNode * ln = Map.find(vCond)->second; + if (not route[i]->getIsTrue()){ + ln = LogicalNode::getNot(ln); } - LogicalNode *node = nullptr; - if (phiNode->getParent() == pTerminatorBranch->getSuccessor(0)) { - /* if the condition is true, we branch to the phi block and return true. */ - node = condIter->second; - } else if (phiNode->getParent() == pTerminatorBranch->getSuccessor(1)) { - /* if the condition is false, we branch to the phi block and return true. */ - node = LogicalNode::getNot(condIter->second); + if (nullptr == ret){ + ret = ln; } else { - assert(0 && "HOW IS THIS POSSIBLE"); + ret = LogicalNode::getAnd(ret, ln); } - nodes.push_back(node); - } else if (pVal == ConstantInt::getFalse(pMod->getContext())) { - assert(pTerminatorBranch->isConditional() && "SHOULD NEVER HAPPEN"); + } + } + + return ret; + + } + + /* + * Our method for processing a phi node is to find all possible paths to a phi node + * that could generate 'true' and Or them together. + * + * For example: Consider the following function. + * + ; Function Attrs: noinline norecurse nounwind readnone uwtable + define dso_local zeroext i1 @logical_trigger() local_unnamed_addr #0 { + entry: + %0 = load i32, i32* getelementptr inbounds ([64 x i32], [64 x i32]* @__clambc_match_counts, i64 0, i64 0), align 16 + %cmp.i = icmp eq i32 %0, 0 + %1 = load i32, i32* getelementptr inbounds ([64 x i32], [64 x i32]* @__clambc_match_counts, i64 0, i64 1), align 4 + %cmp.i53 = icmp eq i32 %1, 0 + %or.cond = or i1 %cmp.i, %cmp.i53 + br i1 %or.cond, label %return, label %if.end + + if.end: ; preds = %entry + %2 = load i32, i32* getelementptr inbounds ([64 x i32], [64 x i32]* @__clambc_match_counts, i64 0, i64 2), align 8 + %cmp.i47 = icmp eq i32 %2, 0 + br i1 %cmp.i47, label %if.else, label %if.then5 + + if.then5: ; preds = %if.end + %3 = load i32, i32* getelementptr inbounds ([64 x i32], [64 x i32]* @__clambc_match_counts, i64 0, i64 4), align 16 + %cmp.i41 = icmp eq i32 %3, 0 + br i1 %cmp.i41, label %if.end20, label %return + + if.else: ; preds = %if.end + %4 = load i32, i32* getelementptr inbounds ([64 x i32], [64 x i32]* @__clambc_match_counts, i64 0, i64 3), align 4 + %cmp.i35 = icmp eq i32 %4, 0 + br i1 %cmp.i35, label %return, label %if.then12 + + if.then12: ; preds = %if.else + %5 = load i32, i32* getelementptr inbounds ([64 x i32], [64 x i32]* @__clambc_match_counts, i64 0, i64 4), align 16 + %cmp = icmp ne i32 %5, 2 + %6 = load i32, i32* getelementptr inbounds ([64 x i32], [64 x i32]* @__clambc_match_counts, i64 0, i64 5), align 4 + %cmp.i25 = icmp eq i32 %6, 0 + %or.cond1 = or i1 %cmp, %cmp.i25 + br i1 %or.cond1, label %if.end20, label %return + + if.end20: ; preds = %if.then12, %if.then5 + br label %return + + return: ; preds = %if.else, %if.then12, %if.then5, %entry, %if.end20 + %retval.0 = phi i1 [ true, %if.end20 ], [ false, %entry ], [ false, %if.then5 ], [ false, %if.then12 ], [ false, %if.else ] + ret i1 %retval.0 + } + + + The phi node is + %retval.0 = phi i1 [ true, %if.end20 ], [ false, %entry ], [ false, %if.then5 ], [ false, %if.then12 ], [ false, %if.else ] + + This can only return true if the %return block is entered from the %if.end20 block. There are two possible cases + for that to happen, which will be OR'd together. - Value *pCond = pTerminatorBranch->getCondition(); + The logical expression for this PHINode is + (%or.cond1 AND (NOT %cmp.i35) AND %cmp.i47 AND (NOT %or.cond)) OR (%cmp.i41 AND (NOT %cmp.i47) AND (NOT %or.cond)) + * + */ + void processPHI(PHINode * pn){ + BasicBlock * phiBlock = pn->getParent(); + BasicBlock * startBlock = llvm::cast(pn->getParent()->getParent()->begin()); - LogicalMap::iterator condIter = Map.find(pCond); - if (Map.end() == condIter) { - /*This should be impossible, but insert meaningful error message here.*/ - assert(0 && "HOW DID THIS HAPPEN?"); + std::vector> routes; + std::vector route; + routes.push_back(route); + populateRoutes(startBlock, phiBlock, routes, 0); + + LogicalNode * ln = nullptr; + + for (size_t i = 0; i < pn->getNumIncomingValues(); i++){ + Value * vIncoming = pn->getIncomingValue(i); + ConstantInt * pci = llvm::dyn_cast(vIncoming); + if (pci){ + if (pci->isZero()){ + continue; } - LogicalNode *node = nullptr; - if (phiNode->getParent() == pTerminatorBranch->getSuccessor(0)) { - /* if the condition is true, we branch to the phi block and return false. */ - node = LogicalNode::getNot(condIter->second); - } else if (phiNode->getParent() == pTerminatorBranch->getSuccessor(1)) { - /* if the condition is false, we branch to the phi block and return false. */ - node = condIter->second; - } else { - assert(0 && "HOW IS THIS POSSIBLE?"); + } + std::vector idxs = findRoute(pn->getIncomingBlock(i), routes); + for (size_t j = 0; j < idxs.size(); j++) { + size_t idx = idxs[j]; + LogicalNode * tmp = getLogicalNode(routes[idx]); + if (nullptr == pci){ //Then this isn't a constant + LogicalNode * l = Map.find(vIncoming)->second; + tmp = LogicalNode::getAnd(tmp, l); } - nodes.push_back(node); - } else { - LogicalMap::iterator condIter = Map.find(pVal); - if (Map.end() == condIter) { - assert(0 && "HOW DID THIS HAPPEN?"); + if (nullptr == ln){ + ln = tmp; + } else { + ln = LogicalNode::getOr(ln, tmp); } - - nodes.push_back(condIter->second); } } - } - - bool processPHI(PHINode *phiNode) - { - - std::vector nodes; - saveNodes(phiNode, nodes); + Map[pn] = ln; - generateLogicalPHI(phiNode, nodes); - return true; + for (size_t i = 0; i < routes.size(); i++){ + for (size_t j = 0; j < routes[i].size(); j++){ + delete(routes[i][j]); + } + } } bool processBB(BasicBlock *BB) @@ -1580,7 +1629,6 @@ bool ClamBCLogicalCompiler::compileVirusNames(Module &M, unsigned kind) pv = llvm::cast(I); CallSite CS(pv); if (!CS.getInstruction()) { - DEBUGERR << "\n"; continue; } if (CS.getCalledFunction() != F) { @@ -1624,6 +1672,7 @@ bool ClamBCLogicalCompiler::compileVirusNames(Module &M, unsigned kind) return Valid; } + bool ClamBCLogicalCompiler::runOnModule(Module &M) { bool Valid = true; @@ -1631,6 +1680,11 @@ bool ClamBCLogicalCompiler::runOnModule(Module &M) virusnames = ""; pMod = &M; + + //dumpPHIGraphs(); + + + // Handle virusname unsigned kind = 0; GlobalVariable *GVKind = M.getGlobalVariable("__clambc_kind");