Skip to content

Commit

Permalink
[c2cpg] Improvements for range-based for-statement and local code fie…
Browse files Browse the repository at this point in the history
…lds (#5215)

- range-based for-statement blocks had multiple outgoing CFG edges because the children were not wrapped into blocks.
- code fields for locals from NamedTypeSpecifiers now preserve static and const modifier
  • Loading branch information
max-leuthaeuser authored Jan 9, 2025
1 parent 166cfab commit 97806c7
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,16 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As
"class"
)

private val KeepTypeKeywords: List[String] = List("unsigned", "volatile")
private val KeepTypeKeywords: List[String] = List("unsigned", "volatile", "const", "static")

protected def cleanType(rawType: String, stripKeywords: Boolean = true): String = {
if (rawType == Defines.Any) return rawType
val tpe =
if (stripKeywords) {
ReservedTypeKeywords.foldLeft(rawType) { (cur, repl) =>
if (cur.contains(s"$repl ")) cur.replace(s"$repl ", "") else cur
if (cur.startsWith(s"$repl ") || cur.contains(s" $repl ")) {
cur.replace(s" $repl ", " ").replace(s"$repl ", "")
} else cur
}
} else {
rawType
Expand All @@ -168,17 +170,23 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As
case t if t.contains(Defines.QualifiedNameSeparator) => replaceWhitespaceAfterTypeKeyword(fixQualifiedName(t))
case t if KeepTypeKeywords.exists(k => t.startsWith(s"$k ")) => replaceWhitespaceAfterTypeKeyword(t)
case t if t.contains("[") && t.contains("]") => replaceWhitespaceAfterTypeKeyword(t)
case t if t.contains("<") && t.contains(">") => replaceWhitespaceAfterTypeKeyword(t)
case t if t.contains("*") => replaceWhitespaceAfterTypeKeyword(t)
case someType => someType
}
}

private def replaceWhitespaceAfterTypeKeyword(tpe: String): String = {
if (KeepTypeKeywords.exists(k => tpe.startsWith(s"$k "))) {
if (KeepTypeKeywords.exists(k => tpe.startsWith(s"$k ") || tpe.contains(s" $k "))) {
KeepTypeKeywords.foldLeft(tpe) { (cur, repl) =>
val prefix = s"$repl "
if (cur.startsWith(prefix)) {
prefix + cur.substring(prefix.length).replace(" ", "")
val prefixStartsWith = s"$repl "
val prefixContains = s" $repl "
if (cur.startsWith(prefixStartsWith)) {
prefixStartsWith + replaceWhitespaceAfterTypeKeyword(cur.substring(prefixStartsWith.length))
} else if (cur.contains(prefixContains)) {
val front = tpe.substring(0, tpe.indexOf(prefixContains))
val back = tpe.substring(tpe.indexOf(prefixContains) + prefixContains.length)
s"${replaceWhitespaceAfterTypeKeyword(front)}$prefixContains${replaceWhitespaceAfterTypeKeyword(back)}"
} else {
cur
}
Expand Down Expand Up @@ -324,7 +332,7 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As
}

private def typeForCPPAstNamedTypeSpecifier(s: ICPPASTNamedTypeSpecifier, stripKeywords: Boolean): String = {
val tpe = safeGetBinding(s).map(_.toString.replace(" ", "")).getOrElse(ASTStringUtil.getReturnTypeString(s, null))
val tpe = safeGetBinding(s).map(_.toString).getOrElse(s.getRawSignature)
cleanType(tpe, stripKeywords)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,14 +308,13 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
val code = s"for ($codeInit$codeCond;$codeIter)"
val forNode = controlStructureNode(forStmt, ControlStructureTypes.FOR, code)

val initAstBlock = blockNode(forStmt, Defines.Empty, registerType(Defines.Void)).order(1)
scope.pushNewScope(initAstBlock)
val initAst = blockAst(initAstBlock, nullSafeAst(forStmt.getInitializerStatement).toList)
val compareAst = astForConditionExpression(forStmt.getConditionExpression, Option(2))
val updateAst = nullSafeAst(forStmt.getIterationExpression, 3)
val bodyAsts = nullSafeAst(forStmt.getBody, 4)
scope.popScope()
forAst(forNode, Seq.empty, Seq(initAst), Seq(compareAst), Seq(updateAst), bodyAsts)
val (localAsts, initAsts) =
nullSafeAst(forStmt.getInitializerStatement).partition(_.root.exists(_.isInstanceOf[NewLocal]))
setArgumentIndices(initAsts)
val compareAst = astForConditionExpression(forStmt.getConditionExpression)
val updateAst = nullSafeAst(forStmt.getIterationExpression)
val bodyAsts = nullSafeAst(forStmt.getBody)
forAst(forNode, localAsts, initAsts, Seq(compareAst), Seq(updateAst), bodyAsts)
}

private def astForRangedFor(forStmt: ICPPASTRangeBasedForStatement): Ast = {
Expand All @@ -325,14 +324,18 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
val forNode = controlStructureNode(forStmt, ControlStructureTypes.FOR, code)
forStmt.getDeclaration match {
case declaration: ICPPASTStructuredBindingDeclaration =>
val initAsts = astsForStructuredBindingDeclaration(declaration, Some(forStmt.getInitializerClause))
val bodyAsts = nullSafeAst(forStmt.getBody, 4)
controlStructureAst(forNode, None, (initAsts ++ bodyAsts).toList)
val (localAsts, initAsts) = astsForStructuredBindingDeclaration(declaration, Some(forStmt.getInitializerClause))
.partition(_.root.exists(_.isInstanceOf[NewLocal]))
setArgumentIndices(initAsts)
val bodyAst = nullSafeAst(forStmt.getBody)
forAst(forNode, localAsts, initAsts.filterNot(_.nodes.isEmpty), Seq.empty, Seq.empty, bodyAst)
case _ =>
val initAst = astForNode(forStmt.getInitializerClause)
val declAst = astsForDeclaration(forStmt.getDeclaration)
val stmtAst = nullSafeAst(forStmt.getBody)
controlStructureAst(forNode, None, Seq(initAst) ++ declAst ++ stmtAst)
val init = astForNode(forStmt.getInitializerClause)
val declAsts = astsForDeclaration(forStmt.getDeclaration)
setArgumentIndices(init +: declAsts)
val (localAsts, initAsts) = (init +: declAsts).partition(_.root.exists(_.isInstanceOf[NewLocal]))
val bodyAst = nullSafeAst(forStmt.getBody)
forAst(forNode, localAsts, initAsts.filterNot(_.nodes.isEmpty), Seq.empty, Seq.empty, bodyAst)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -700,15 +700,16 @@ class AstCreationPassTests extends AstC2CpgSuite {
)
inside(cpg.method.nameExact("method").controlStructure.l) { case List(forStmt) =>
forStmt.controlStructureType shouldBe ControlStructureTypes.FOR
inside(forStmt.astChildren.order(1).l) { case List(ident: Identifier) =>
ident.code shouldBe "list"
}
inside(forStmt.astChildren.order(2).l) { case List(x: Local) =>
inside(forStmt.astChildren.isLocal.l) { case List(x: Local) =>
x.name shouldBe "x"
x.typeFullName shouldBe "int"
x.code shouldBe "int x"
}
inside(forStmt.astChildren.order(3).l) { case List(block: Block) =>
// for the expected orders see CfgCreator.cfgForForStatement
inside(forStmt.astChildren.order(2).l) { case List(ident: Identifier) =>
ident.code shouldBe "list"
}
inside(forStmt.astChildren.order(5).l) { case List(block: Block) =>
block.astChildren.isCall.code.l shouldBe List("z = x")
}
}
Expand All @@ -726,7 +727,7 @@ class AstCreationPassTests extends AstC2CpgSuite {
)
inside(cpg.method.nameExact("method").controlStructure.l) { case List(forStmt) =>
forStmt.controlStructureType shouldBe ControlStructureTypes.FOR
forStmt.astChildren.isCall.code.l shouldBe List(
forStmt.astChildren.isBlock.astChildren.isCall.code.l shouldBe List(
"anonymous_tmp_0 = foo",
"a = anonymous_tmp_0[0]",
"b = anonymous_tmp_0[1]"
Expand Down Expand Up @@ -819,7 +820,7 @@ class AstCreationPassTests extends AstC2CpgSuite {
""".stripMargin)
val List(forLoop) = cpg.controlStructure.l
val List(conditionBlock) = forLoop.condition.collectAll[Block].l
conditionBlock.argumentIndex shouldBe 2
conditionBlock.order shouldBe 2
val List(assignmentCall, greaterCall) = conditionBlock.astChildren.collectAll[Call].l
assignmentCall.argumentIndex shouldBe 1
assignmentCall.code shouldBe "b = something()"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,26 +92,26 @@ class ControlStructureTests extends C2CpgSuite(FileDefaults.CppExt) {
"should be correct for for-loop with multiple assignments" in {
inside(cpg.controlStructure.l) { case List(forLoop) =>
forLoop.controlStructureType shouldBe ControlStructureTypes.FOR
inside(forLoop.astChildren.order(1).l) { case List(assignmentBlock) =>
inside(assignmentBlock.astChildren.l) { case List(localX, localY, assignmentX, assignmentY) =>
localX.code shouldBe "int x"
localX.order shouldBe 1
localY.code shouldBe "int y"
localY.order shouldBe 2
inside(forLoop.astChildren.isLocal.l) { case List(localX, localY) =>
localX.code shouldBe "int x"
localY.code shouldBe "int y"
}
inside(forLoop.astChildren.order(3).l) { case List(assignmentBlock) =>
inside(assignmentBlock.astChildren.l) { case List(assignmentX, assignmentY) =>
assignmentX.code shouldBe "x=1"
assignmentX.order shouldBe 3
assignmentX.order shouldBe 1
assignmentY.code shouldBe "y=1"
assignmentY.order shouldBe 4
assignmentY.order shouldBe 2
}
}
inside(forLoop.condition.l) { case List(x) =>
x.code shouldBe "x"
x.order shouldBe 2
x.order shouldBe 4
}
inside(forLoop.astChildren.order(3).l) { case List(updateX) =>
inside(forLoop.astChildren.order(5).l) { case List(updateX) =>
updateX.code shouldBe "--x"
}
inside(forLoop.astChildren.order(4).l) { case List(loopBody) =>
inside(forLoop.astChildren.order(6).l) { case List(loopBody) =>
loopBody.astChildren.isCall.head.code shouldBe "bar()"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class TemplateTypeTests extends C2CpgSuite(fileSuffix = FileDefaults.CppExt) {
typeDeclA.aliasTypeFullName shouldBe Option("X<int>")
typeDeclB.name shouldBe "B"
typeDeclB.fullName shouldBe "B"
typeDeclB.aliasTypeFullName shouldBe Option("Y<int, char>")
typeDeclB.aliasTypeFullName shouldBe Option("Y<int,char>")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,79 +7,97 @@ import io.shiftleft.semanticcpg.language.*
*/
class LocalQueryTests extends C2CpgSuite {

private val cpg = code("""
| struct node {
| int value;
| struct node *next;
| };
|
| void free_list(struct node *head) {
| struct node *q;
| for (struct node *p = head; p != NULL; p = q) {
| q = p->next;
| free(p);
| }
| }
|
| int flow(int p0) {
| int a = p0;
| int b = a;
| int c = 0x31;
| int z = b + c;
| z++;
| int x = z;
| return x;
| }
|
| void test() {
| static int a, b, c;
| wchar_t *foo;
| int d[10], e = 1;
| }
| """.stripMargin)

"should allow to query for all locals" in {
cpg.local.name.toSetMutable shouldBe Set("a", "b", "c", "e", "d", "z", "x", "q", "p", "foo")
"local query example 1" should {
"allow to query for the local" in {
val cpg = code(
"""
|void foo() {
| static const Foo::Bar bar{};
|}
|""".stripMargin,
"test.cpp"
)
val List(barLocal) = cpg.method.name("foo").local.l
barLocal.name shouldBe "bar"
barLocal.typeFullName shouldBe "Foo.Bar"
barLocal.code shouldBe "static const Foo.Bar bar"
}
}

"should prove correct (name, type) pairs for locals" in {
inside(cpg.method.name("free_list").local.l) { case List(q, p) =>
q.name shouldBe "q"
q.typeFullName shouldBe "node*"
q.code shouldBe "struct node* q"
p.name shouldBe "p"
p.typeFullName shouldBe "node*"
p.code shouldBe "struct node* p"
"local query example 2" should {
val cpg = code("""
| struct node {
| int value;
| struct node *next;
| };
|
| void free_list(struct node *head) {
| struct node *q;
| for (struct node *p = head; p != NULL; p = q) {
| q = p->next;
| free(p);
| }
| }
|
| int flow(int p0) {
| int a = p0;
| int b = a;
| int c = 0x31;
| int z = b + c;
| z++;
| int x = z;
| return x;
| }
|
| void test() {
| static int a, b, c;
| wchar_t *foo;
| int d[10], e = 1;
| }
| """.stripMargin)

"should allow to query for all locals" in {
cpg.local.name.toSetMutable shouldBe Set("a", "b", "c", "e", "d", "z", "x", "q", "p", "foo")
}
}

"should prove correct (name, type, code) pairs for locals" in {
inside(cpg.method.name("test").local.l) { case List(a, b, c, foo, d, e) =>
a.name shouldBe "a"
a.typeFullName shouldBe "int"
a.code shouldBe "static int a"
b.name shouldBe "b"
b.typeFullName shouldBe "int"
b.code shouldBe "static int b"
c.name shouldBe "c"
c.typeFullName shouldBe "int"
c.code shouldBe "static int c"
foo.name shouldBe "foo"
foo.typeFullName shouldBe "wchar_t*"
foo.code shouldBe "wchar_t* foo"
d.name shouldBe "d"
d.typeFullName shouldBe "int[10]"
d.code shouldBe "int[10] d"
e.name shouldBe "e"
e.typeFullName shouldBe "int"
e.code shouldBe "int e"
"should prove correct (name, type) pairs for locals" in {
inside(cpg.method.name("free_list").local.l) { case List(q, p) =>
q.name shouldBe "q"
q.typeFullName shouldBe "node*"
q.code shouldBe "struct node* q"
p.name shouldBe "p"
p.typeFullName shouldBe "node*"
p.code shouldBe "struct node* p"
}
}
}

"should allow finding filenames by local regex" in {
val filename = cpg.local.name("a*").file.name.headOption
filename should not be empty
filename.head.endsWith(".c") shouldBe true
}
"should prove correct (name, type, code) pairs for locals" in {
inside(cpg.method.name("test").local.l) { case List(a, b, c, foo, d, e) =>
a.name shouldBe "a"
a.typeFullName shouldBe "int"
a.code shouldBe "static int a"
b.name shouldBe "b"
b.typeFullName shouldBe "int"
b.code shouldBe "static int b"
c.name shouldBe "c"
c.typeFullName shouldBe "int"
c.code shouldBe "static int c"
foo.name shouldBe "foo"
foo.typeFullName shouldBe "wchar_t*"
foo.code shouldBe "wchar_t* foo"
d.name shouldBe "d"
d.typeFullName shouldBe "int[10]"
d.code shouldBe "int[10] d"
e.name shouldBe "e"
e.typeFullName shouldBe "int"
e.code shouldBe "int e"
}
}

"should allow finding filenames by local regex" in {
val filename = cpg.local.name("a*").file.name.headOption
filename should not be empty
filename.head.endsWith(".c") shouldBe true
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,13 @@ abstract class AstCreatorBase(filename: String)(implicit withSchemaValidation: V
): Ast =
forAst(forNode, locals, initAsts, conditionAsts, updateAsts, Seq(bodyAst))

private def setOrderExplicitly(ast: Ast, order: Int): Ast = {
ast.root match {
case Some(value: ExpressionNew) => value.order(order); ast
case _ => ast
}
}

def forAst(
forNode: NewControlStructure,
locals: Seq[Ast],
Expand All @@ -206,12 +213,15 @@ abstract class AstCreatorBase(filename: String)(implicit withSchemaValidation: V
updateAsts: Seq[Ast],
bodyAsts: Seq[Ast]
): Ast = {
val lineNumber = forNode.lineNumber
val lineNumber = forNode.lineNumber
val numOfLocals = locals.size
// for the expected orders see CfgCreator.cfgForForStatement
if (bodyAsts.nonEmpty) setOrderExplicitly(bodyAsts.head, numOfLocals + 4)
Ast(forNode)
.withChildren(locals)
.withChild(wrapMultipleInBlock(initAsts, lineNumber))
.withChild(wrapMultipleInBlock(conditionAsts, lineNumber))
.withChild(wrapMultipleInBlock(updateAsts, lineNumber))
.withChild(setOrderExplicitly(wrapMultipleInBlock(initAsts, lineNumber), numOfLocals + 1))
.withChild(setOrderExplicitly(wrapMultipleInBlock(conditionAsts, lineNumber), numOfLocals + 2))
.withChild(setOrderExplicitly(wrapMultipleInBlock(updateAsts, lineNumber), numOfLocals + 3))
.withChildren(bodyAsts)
.withConditionEdges(forNode, conditionAsts.flatMap(_.root).toList)
}
Expand Down

0 comments on commit 97806c7

Please sign in to comment.