Skip to content

Commit

Permalink
Add support for java type pattern matching (#5140)
Browse files Browse the repository at this point in the history
  • Loading branch information
johannescoetzee authored Nov 27, 2024
1 parent f48ac0f commit c52986f
Show file tree
Hide file tree
Showing 17 changed files with 3,475 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ import io.joern.javasrc2cpg.util.{
BindingTable,
BindingTableAdapterForJavaparser,
MultiBindingTableAdapterForJavaparser,
NameConstants
NameConstants,
TemporaryNameProvider
}
import io.joern.x2cpg.datastructures.Global
import io.joern.x2cpg.utils.OffsetUtils
Expand Down Expand Up @@ -105,6 +106,8 @@ class AstCreator(
TypeInfoCalculator(global, symbolSolver, keepTypeArguments)
private[astcreation] val bindingTableCache = mutable.HashMap.empty[String, BindingTable]

private[astcreation] val tempNameProvider: TemporaryNameProvider = new TemporaryNameProvider

/** Entry point of AST creation. Translates a compilation unit created by JavaParser into a DiffGraph containing the
* corresponding CPG AST.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator =>
scope.enclosingMethod.get.addParameter(node)
}

val bodyAst = methodDeclaration.getBody.toScala.map(astForBlockStatement(_)).getOrElse(Ast(NewBlock()))
val bodyAst = methodDeclaration.getBody.toScala
.map(astForBlockStatement(_, includeTemporaryLocals = true))
.getOrElse(Ast(NewBlock()))
val (lineNr, columnNr) = tryWithSafeStackOverflow(methodDeclaration.getType) match {
case Success(typ) => (line(typ), column(typ))
case Failure(_) => (line(methodDeclaration), column(methodDeclaration))
Expand Down Expand Up @@ -173,8 +175,9 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator =>
fieldDeclaration.getVariables.asScala.filter(_.getInitializer.isPresent).toList.flatMap { variableDeclaration =>
scope.pushFieldDeclScope(fieldDeclaration.isStatic, variableDeclaration.getNameAsString)
val assignmentAsts = astsForVariableDeclarator(variableDeclaration, fieldDeclaration)
val patternAsts = scope.enclosingMethod.get.getUnaddedPatternVariableAstsAndMarkAdded()
scope.popFieldDeclScope()
assignmentAsts
patternAsts ++ assignmentAsts
}
}
}
Expand All @@ -198,7 +201,8 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator =>

val thisNode = thisNodeForMethod(typeFullName, lineNumber = None)
scope.enclosingMethod.foreach(_.addParameter(thisNode))
val bodyStatementAsts = astsForFieldInitializers(instanceFieldDeclarations)
val bodyStatementAsts = astsForFieldInitializers(instanceFieldDeclarations)
val temporaryLocalAsts = scope.enclosingMethod.map(_.getTemporaryLocals).getOrElse(Nil).map(Ast(_))

val returnNode = newMethodReturnNode(TypeConstants.Void, line = None, column = None)

Expand All @@ -209,7 +213,7 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator =>
constructorNode,
thisNode,
explicitParameterAsts = Nil,
bodyStatementAsts = bodyStatementAsts,
bodyStatementAsts = temporaryLocalAsts ++ bodyStatementAsts,
methodReturn = returnNode,
annotationAsts = Nil,
modifiers = modifiers,
Expand Down Expand Up @@ -345,18 +349,21 @@ private[declarations] trait AstForMethodsCreator { this: AstCreator =>

scope.pushBlockScope()
val bodyStatements = constructorDeclaration.getBody.getStatements.asScala.toList
val statementsAsts = bodyStatements.flatMap(astsForStatement)
val bodyContainsThis = bodyStatements.headOption
.collect { case consInvocation: ExplicitConstructorInvocationStmt => consInvocation.isThis }
.getOrElse(false)
val fieldAssignments =
val fieldAssignmentsAndTempLocals =
if (bodyContainsThis)
Nil
else
astsForFieldInitializers(instanceFieldDeclarations)
scope.enclosingMethod.get.getTemporaryLocals.map(Ast(_)) ++ astsForFieldInitializers(
instanceFieldDeclarations
)

// The this(...) call must always be the first statement in the body, but adding the fieldAssignments
// The this(...) call must always be the first statement in the body, but adding the fieldAssignmentsAndTempLocals
// before the body asts here is safe, since the list will be empty if the body does start with this()
val bodyAsts = fieldAssignments ++ bodyStatements.flatMap(astsForStatement)
val bodyAsts = fieldAssignmentsAndTempLocals ++ statementsAsts
scope.popBlockScope()
val methodReturn = constructorReturnNode(constructorDeclaration)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import com.github.javaparser.resolution.declarations.{
ResolvedReferenceTypeDeclaration,
ResolvedTypeParameterDeclaration
}
import io.joern.javasrc2cpg.astcreation.AstCreator
import io.joern.javasrc2cpg.astcreation.{AstCreator, ExpectedType}
import io.joern.javasrc2cpg.typesolvers.TypeInfoCalculator.TypeConstants
import io.joern.javasrc2cpg.util.{BindingTable, BindingTableEntry, NameConstants, Util}
import io.joern.x2cpg.utils.NodeBuilders.*
Expand Down Expand Up @@ -443,14 +443,18 @@ private[declarations] trait AstForTypeDeclsCreator { this: AstCreator =>
}

private def getStaticFieldInitializers(staticFields: List[FieldDeclaration]): List[Ast] = {
staticFields.flatMap { field =>
scope.pushMethodScope(NewMethod(), ExpectedType.empty, isStatic = true)
val fieldsAsts = staticFields.flatMap { field =>
field.getVariables.asScala.toList.flatMap { variable =>
scope.pushFieldDeclScope(isStatic = true, name = variable.getNameAsString)
val assignment = astsForVariableDeclarator(variable, field)
scope.popFieldDeclScope()
assignment
}
}
val methodScope = scope.popMethodScope()
methodScope.getTemporaryLocals.map(Ast(_)) ++ methodScope
.getUnaddedPatternVariableAstsAndMarkAdded() ++ fieldsAsts
}

private[declarations] def astForAnnotationExpr(annotationExpr: AnnotationExpr): Ast = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ trait AstForCallExpressionsCreator { this: AstCreator =>

private val logger = LoggerFactory.getLogger(this.getClass)

private var tempConstCount = 0

private[expressions] def astForMethodCall(call: MethodCallExpr, expectedReturnType: ExpectedType): Ast = {
val maybeResolvedCall = tryWithSafeStackOverflow(call.resolve())
val argumentAsts = argAstsForCall(call, maybeResolvedCall, call.getArguments)
Expand Down Expand Up @@ -143,8 +141,7 @@ trait AstForCallExpressionsCreator { this: AstCreator =>
}

private[expressions] def blockAstForObjectCreationExpr(expr: ObjectCreationExpr, expectedType: ExpectedType): Ast = {
val tmpName = "$obj" ++ tempConstCount.toString
tempConstCount += 1
val tmpName = tempNameProvider.next

// Use an untyped identifier for receiver here, create the alloc and init ASTs,
// then use the types of those to fix the local type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ trait AstForExpressionsCreator
with AstForLambdasCreator
with AstForCallExpressionsCreator
with AstForNameExpressionsCreator
with AstForPatternExpressionsCreator
with AstForVarDeclAndAssignsCreator { this: AstCreator =>
def astsForExpression(expression: Expression, expectedType: ExpectedType): Seq[Ast] = {
// TODO: Implement missing handlers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ import io.joern.javasrc2cpg.scope.Scope.{
NotInScope,
ScopeMember,
ScopeParameter,
ScopePatternVariable,
ScopeVariable,
SimpleVariable
}
import org.slf4j.LoggerFactory
import io.joern.x2cpg.utils.AstPropertiesUtil.*
import io.shiftleft.codepropertygraph.generated.Operators
import io.joern.x2cpg.utils.NodeBuilders.{newIdentifierNode, newOperatorCallNode}
import io.joern.javasrc2cpg.scope.PatternVariableInfo

trait AstForNameExpressionsCreator { this: AstCreator =>

Expand All @@ -54,6 +56,16 @@ trait AstForNameExpressionsCreator { this: AstCreator =>
variable.typeFullName
)

case SimpleVariable(ScopePatternVariable(localNode, typePatternExpr)) =>
scope.enclosingMethod.flatMap(_.getPatternVariableInfo(typePatternExpr)) match {
case Some(PatternVariableInfo(typePatternExpr, _, initializerAst, _, false)) =>
scope.enclosingMethod.foreach(_.registerPatternVariableInitializerToBeAddedToGraph(typePatternExpr))
initializerAst
case _ =>
val identifier = identifierNode(nameExpr, localNode.name, localNode.name, localNode.typeFullName)
Ast(identifier).withRefEdge(identifier, localNode)
}

case SimpleVariable(variable) =>
val identifier = identifierNode(nameExpr, name, name, typeFullName.getOrElse(TypeConstants.Any))
val captured = variable.node match {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package io.joern.javasrc2cpg.astcreation.expressions

import com.github.javaparser.ast.Node
import com.github.javaparser.ast.expr.{
Expression,
InstanceOfExpr,
NameExpr,
PatternExpr,
RecordPatternExpr,
TypePatternExpr
}
import io.joern.javasrc2cpg.astcreation.{AstCreator, ExpectedType}
import io.joern.javasrc2cpg.jartypereader.model.Model.TypeConstants
import io.joern.javasrc2cpg.scope.Scope.NewVariableNode
import io.joern.x2cpg.Ast
import io.joern.x2cpg.utils.AstPropertiesUtil.*
import io.shiftleft.codepropertygraph.generated.nodes.{AstNodeNew, NewIdentifier}
import io.joern.x2cpg.utils.NodeBuilders.*
import io.shiftleft.codepropertygraph.generated.Operators

import scala.jdk.CollectionConverters.*

trait AstForPatternExpressionsCreator { this: AstCreator =>

private[astcreation] def astIdentifierAndRefsForPatternLhs(
rootNode: Node,
patternInitAst: Ast
): (Ast, NewIdentifier, Option[NewVariableNode]) = {
patternInitAst.nodes.toList match {
case (identifier: NewIdentifier) :: Nil =>
(patternInitAst, identifier, scope.lookupVariable(identifier.name).variableNode)

case _ =>
val tmpName = tempNameProvider.next
val tmpType = patternInitAst.rootType.getOrElse(TypeConstants.Object)
val tmpLocal = localNode(rootNode, tmpName, tmpName, tmpType)
val tmpIdentifier = identifierNode(rootNode, tmpName, tmpName, tmpType)

val tmpAssignmentNode =
newOperatorCallNode(
Operators.assignment,
s"$tmpName = ${patternInitAst.rootCodeOrEmpty}",
Option(tmpType),
line(rootNode),
column(rootNode)
)

// Don't need to add the local to the block scope since the only identifiers referencing it are created here
// (so a lookup for the local will never be done)
scope.enclosingMethod.foreach(_.addTemporaryLocal(tmpLocal))

(
callAst(tmpAssignmentNode, Ast(tmpIdentifier) :: patternInitAst :: Nil).withRefEdge(tmpIdentifier, tmpLocal),
tmpIdentifier,
Option(tmpLocal)
)
}
}

private[astcreation] def astForInstanceOfWithPattern(
instanceOfLhsExpr: Expression,
patternLhsInitAst: Ast,
pattern: PatternExpr
): Ast = {
val (lhsAst, lhsIdentifier, lhsRefsTo) = astIdentifierAndRefsForPatternLhs(instanceOfLhsExpr, patternLhsInitAst)

val patternTypeFullName = {
tryWithSafeStackOverflow(pattern.getType).toOption
.flatMap(typ => scope.lookupScopeType(typ.asString()).map(_.typeFullName).orElse(typeInfoCalc.fullName(typ)))
.getOrElse(TypeConstants.Any)
}

val patternTypeRef = typeRefNode(pattern.getType, code(pattern.getType), patternTypeFullName)

val typePatterns = getTypePatterns(pattern)

typePatterns.foreach { typePatternExpr =>
val variableName = typePatternExpr.getNameAsString
val variableType = {
tryWithSafeStackOverflow(typePatternExpr.getType).toOption
.flatMap(typ => scope.lookupScopeType(typ.asString()).map(_.typeFullName).orElse(typeInfoCalc.fullName(typ)))
.getOrElse(TypeConstants.Any)
}
val variableTypeCode = tryWithSafeStackOverflow(code(typePatternExpr.getType)).getOrElse(variableType)

val patternLocal = localNode(typePatternExpr, variableName, code(typePatternExpr), variableType)
val patternIdentifier = identifierNode(typePatternExpr, variableName, variableName, variableType)
// TODO Handle record pattern initializers
val patternInitializerCastType = typeRefNode(typePatternExpr, code(typePatternExpr.getType), variableType)
val patternInitializerCastRhs = lhsIdentifier.copy
val patternInitializerCast = newOperatorCallNode(
Operators.cast,
s"($variableTypeCode) ${lhsIdentifier.code}",
Option(variableType),
line(typePatternExpr),
column(typePatternExpr)
)

val initializerCastAst =
callAst(patternInitializerCast, Ast(patternInitializerCastType) :: Ast(patternInitializerCastRhs) :: Nil)
.withRefEdges(patternInitializerCastRhs, lhsRefsTo.toList)

val initializerAssignmentCall = newOperatorCallNode(
Operators.assignment,
s"$variableName = ${patternInitializerCast.code}",
Option(variableType),
line(typePatternExpr),
column(typePatternExpr)
)
val initializerAssignmentAst = callAst(
initializerAssignmentCall,
Ast(patternIdentifier) :: initializerCastAst :: Nil
).withRefEdge(patternIdentifier, patternLocal)

scope.enclosingMethod.foreach { methodScope =>
methodScope.putPatternVariableInfo(typePatternExpr, patternLocal, initializerAssignmentAst)
}
}

val instanceOfCall = newOperatorCallNode(
Operators.instanceOf,
s"${lhsAst.rootCodeOrEmpty} instanceof ${code(pattern.getType)}",
Option(TypeConstants.Boolean)
)

callAst(instanceOfCall, lhsAst :: Ast(patternTypeRef) :: Nil)
}

private def getTypePatterns(expr: PatternExpr): List[TypePatternExpr] = {
expr match {
case typePatternExpr: TypePatternExpr => typePatternExpr :: Nil

case recordPatternExpr: RecordPatternExpr =>
recordPatternExpr.getPatternList.asScala.toList.flatMap(getTypePatterns)
}
}
}
Loading

0 comments on commit c52986f

Please sign in to comment.