Skip to content

Commit

Permalink
[ruby] Move lowering of ||= and &&= to AstCreator (#5055)
Browse files Browse the repository at this point in the history
* Moved lowering for ||= and &&= to AstCreator

* Moved lowering func to AstCreatorHelper trait
  • Loading branch information
AndreiDreyer authored Nov 5, 2024
1 parent 93876e5 commit 80145a8
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 55 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
package io.joern.rubysrc2cpg.astcreation
import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{
ClassFieldIdentifier,
ControlFlowStatement,
DummyNode,
IfExpression,
InstanceFieldIdentifier,
MemberAccess,
RubyExpression,
RubyFieldIdentifier,
RubyExpression
SingleAssignment,
StatementList,
TextSpan,
UnaryExpression
}
import io.joern.rubysrc2cpg.datastructures.{BlockScope, FieldDecl}
import io.joern.rubysrc2cpg.passes.Defines
Expand Down Expand Up @@ -146,6 +152,44 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As
member
}

/** Lowers the `||=` and `&&=` assignment operators to the respective `.nil?` checks
*/
def lowerAssignmentOperator(lhs: RubyExpression, rhs: RubyExpression, op: String, span: TextSpan): RubyExpression &
ControlFlowStatement = {
val condition = nilCheckCondition(lhs, op, "nil?", span)
val thenClause = nilCheckThenClause(lhs, rhs, span)
nilCheckIfStatement(condition, thenClause, span)
}

/** Generates the required `.nil?` check condition used in the lowering of `||=` and `&&=`
*/
private def nilCheckCondition(lhs: RubyExpression, op: String, memberName: String, span: TextSpan): RubyExpression = {
val memberAccess =
MemberAccess(lhs, op = ".", memberName = "nil?")(span.spanStart(s"${lhs.span.text}.nil?"))
if op == "||=" then memberAccess
else UnaryExpression(op = "!", expression = memberAccess)(span.spanStart(s"!${memberAccess.span.text}"))
}

/** Generates the assignment and the `thenClause` used in the lowering of `||=` and `&&=`
*/
private def nilCheckThenClause(lhs: RubyExpression, rhs: RubyExpression, span: TextSpan): RubyExpression = {
StatementList(List(SingleAssignment(lhs, "=", rhs)(span.spanStart(s"${lhs.span.text} = ${rhs.span.text}"))))(
span.spanStart(s"${lhs.span.text} = ${rhs.span.text}")
)
}

/** Generates the if statement for the lowering of `||=` and `&&=`
*/
private def nilCheckIfStatement(
condition: RubyExpression,
thenClause: RubyExpression,
span: TextSpan
): RubyExpression & ControlFlowStatement = {
IfExpression(condition = condition, thenClause = thenClause, elsifClauses = List.empty, elseClause = None)(
span.spanStart(s"if ${condition.span.text} then ${thenClause.span.text} end")
)
}

protected val UnaryOperatorNames: Map[String, String] = Map(
"!" -> Operators.logicalNot,
"not" -> Operators.logicalNot,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{
IfExpression,
MemberCall,
NextExpression,
OperatorAssignment,
RescueExpression,
ReturnExpression,
RubyExpression,
Expand All @@ -25,6 +26,7 @@ import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{
WhenClause,
WhileExpression
}
import io.joern.rubysrc2cpg.parser.RubyJsonHelpers
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.x2cpg.{Ast, ValidationMode}
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators}
Expand All @@ -39,16 +41,17 @@ import io.shiftleft.codepropertygraph.generated.nodes.{
trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>

protected def astForControlStructureExpression(node: ControlFlowStatement): Ast = node match {
case node: WhileExpression => astForWhileStatement(node)
case node: DoWhileExpression => astForDoWhileStatement(node)
case node: UntilExpression => astForUntilStatement(node)
case node: CaseExpression => blockAst(NewBlock(), astsForCaseExpression(node).toList)
case node: IfExpression => astForIfExpression(node)
case node: UnlessExpression => astForUnlessStatement(node)
case node: ForExpression => astForForExpression(node)
case node: RescueExpression => astForRescueExpression(node)
case node: NextExpression => astForNextExpression(node)
case node: BreakExpression => astForBreakExpression(node)
case node: WhileExpression => astForWhileStatement(node)
case node: DoWhileExpression => astForDoWhileStatement(node)
case node: UntilExpression => astForUntilStatement(node)
case node: CaseExpression => blockAst(NewBlock(), astsForCaseExpression(node).toList)
case node: IfExpression => astForIfExpression(node)
case node: UnlessExpression => astForUnlessStatement(node)
case node: ForExpression => astForForExpression(node)
case node: RescueExpression => astForRescueExpression(node)
case node: NextExpression => astForNextExpression(node)
case node: BreakExpression => astForBreakExpression(node)
case node: OperatorAssignment => astForOperatorAssignmentExpression(node)
}

private def astForWhileStatement(node: WhileExpression): Ast = {
Expand Down Expand Up @@ -295,4 +298,9 @@ trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMo
astsForStatement(generatedNode)
}

private def astForOperatorAssignmentExpression(node: OperatorAssignment): Ast = {
val loweredAssignment = lowerAssignmentOperator(node.lhs, node.rhs, node.op, node.span)
astForControlStructureExpression(loweredAssignment)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package io.joern.rubysrc2cpg.astcreation

import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{Unknown, Block as RubyBlock, *}
import io.joern.rubysrc2cpg.datastructures.BlockScope
import io.joern.rubysrc2cpg.parser.RubyJsonHelpers
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.rubysrc2cpg.passes.GlobalTypes
import io.joern.rubysrc2cpg.passes.Defines.{RubyOperators, getBuiltInType}
Expand Down Expand Up @@ -482,7 +483,14 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
elseAssignNil
)

astForExpression(transform(cfNode))
cfNode match {
case x @ OperatorAssignment(lhs, op, rhs) =>
val loweredNode = lowerAssignmentOperator(lhs, rhs, op, x.span)
astForExpression(transform(loweredNode))
case x =>
astForExpression(transform(cfNode))
}

case _ =>
// The if the LHS defines a new variable, put the local variable into scope
val lhsAst = node.lhs match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package io.joern.rubysrc2cpg.astcreation

import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{RubyStatement, *}
import io.joern.rubysrc2cpg.datastructures.BlockScope
import io.joern.rubysrc2cpg.parser.RubyJsonHelpers
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType
import io.joern.x2cpg.{Ast, ValidationMode}
Expand All @@ -14,6 +15,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
baseAstCache.clear() // A safe approximation on where to reset the cache
node match {
case node: IfExpression => astForIfStatement(node)
case node: OperatorAssignment => astForOperatorAssignment(node)
case node: CaseExpression => astsForCaseExpression(node)
case node: StatementList => astForStatementList(node) :: Nil
case node: ReturnExpression => astForReturnExpression(node) :: Nil
Expand Down Expand Up @@ -48,6 +50,11 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
}
}

private def astForOperatorAssignment(node: OperatorAssignment): Seq[Ast] = {
val loweredAssignment = lowerAssignmentOperator(node.lhs, node.rhs, node.op, node.span)
astsForStatement(loweredAssignment)
}

private def astForJsonIfStatement(node: IfExpression): Seq[Ast] = {
val conditionAst = astForExpression(node.condition)
val thenAst = astForThenClause(node.thenClause)
Expand Down Expand Up @@ -177,7 +184,14 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
case expr: ControlFlowStatement =>
def transform(e: RubyExpression & ControlFlowStatement): RubyExpression =
transformLastRubyNodeInControlFlowExpressionBody(e, returnLastNode(_, transform), elseReturnNil)
astsForStatement(transform(expr))

expr match {
case x @ OperatorAssignment(lhs, op, rhs) =>
val loweredAssignment = lowerAssignmentOperator(lhs, rhs, op, x.span)
astsForStatement(transform(loweredAssignment))
case x =>
astsForStatement(transform(expr))
}
case node: MemberCallWithBlock => returnAstForRubyCall(node)
case node: SimpleCallWithBlock => returnAstForRubyCall(node)
case _: (LiteralExpr | BinaryExpression | UnaryExpression | SimpleIdentifier | SelfIdentifier | IndexAccess |
Expand Down Expand Up @@ -208,9 +222,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
val nilReturnLiteral = StaticLiteral(Defines.NilClass)(nilReturnSpan)
stmts.map(astForExpression) ++ astsForImplicitReturnStatement(nilReturnLiteral)
case node =>
logger.warn(
s"Implicit return here not supported yet: ${node.text} (${node.getClass.getSimpleName}), only generating statement"
)
logger.warn(s" not supported yet: ${node.text} (${node.getClass.getSimpleName}), only generating statement")
astsForStatement(node).toList
}

Expand Down Expand Up @@ -311,6 +323,9 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
case WhileExpression(condition, body) => WhileExpression(condition, transform(body))(node.span)
case DoWhileExpression(condition, body) => DoWhileExpression(condition, transform(body))(node.span)
case UntilExpression(condition, body) => UntilExpression(condition, transform(body))(node.span)
case OperatorAssignment(lhs, op, rhs) =>
val loweredNode = lowerAssignmentOperator(lhs, rhs, op, node.span)
transformLastRubyNodeInControlFlowExpressionBody(loweredNode, transform, defaultElseBranch)
case IfExpression(condition, thenClause, elsifClauses, elseClause) =>
IfExpression(
condition,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,11 @@ object RubyIntermediateAst {
def assignments: List[SingleAssignment]
}

final case class OperatorAssignment(lhs: RubyExpression, op: String, rhs: RubyExpression)(span: TextSpan)
extends RubyExpression(span)
with RubyStatement
with ControlFlowStatement

final case class DefaultMultipleAssignment(assignments: List[SingleAssignment])(span: TextSpan)
extends RubyExpression(span)
with MultipleAssignment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import io.joern.rubysrc2cpg.astcreation.RubyIntermediateAst.{
AllowedTypeDeclarationChild,
ArrayLiteral,
ClassFieldIdentifier,
ControlFlowStatement,
DefaultMultipleAssignment,
IfExpression,
MemberAccess,
Expand Down Expand Up @@ -167,43 +168,6 @@ object RubyJsonHelpers {
}
}

/** Lowers the `||=` and `&&=` assignment operators to the respective `.nil?` checks
*/
def lowerAssignmentOperator(lhs: RubyExpression, rhs: RubyExpression, op: String, span: TextSpan): RubyExpression = {
val condition = nilCheckCondition(lhs, op, "nil?", span)
val thenClause = nilCheckThenClause(lhs, rhs, span)
nilCheckIfStatement(condition, thenClause, span)
}

/** Generates the required `.nil?` check condition used in the lowering of `||=` and `&&=`
*/
private def nilCheckCondition(lhs: RubyExpression, op: String, memberName: String, span: TextSpan): RubyExpression = {
val memberAccess =
MemberAccess(lhs, op = ".", memberName = "nil?")(span.spanStart(s"${lhs.span.text}.nil?"))
if op == "||=" then memberAccess
else UnaryExpression(op = "!", expression = memberAccess)(span.spanStart(s"!${memberAccess.span.text}"))
}

/** Generates the assignment and the `thenClause` used in the lowering of `||=` and `&&=`
*/
private def nilCheckThenClause(lhs: RubyExpression, rhs: RubyExpression, span: TextSpan): RubyExpression = {
StatementList(List(SingleAssignment(lhs, "=", rhs)(span.spanStart(s"${lhs.span.text} = ${rhs.span.text}"))))(
span.spanStart(s"${lhs.span.text} = ${rhs.span.text}")
)
}

/** Generates the if statement for the lowering of `||=` and `&&=`
*/
private def nilCheckIfStatement(
condition: RubyExpression,
thenClause: RubyExpression,
span: TextSpan
): RubyExpression = {
IfExpression(condition = condition, thenClause = thenClause, elsifClauses = List.empty, elseClause = None)(
span.spanStart(s"if ${condition.span.text} then ${thenClause.span.text} end")
)
}

def lowerMultipleAssignment(
obj: ujson.Obj,
lhsNodes: List[RubyExpression],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ class RubyJsonToNodeCreator(
case x => x
}
val rhs = visit(obj(ParserKeys.Rhs))
lowerAssignmentOperator(lhs, rhs, "&&=", obj.toTextSpan)
OperatorAssignment(lhs, "&&=", rhs)(obj.toTextSpan)
}

private def visitArg(obj: Obj): RubyExpression = MandatoryParameter(obj(ParserKeys.Value).str)(obj.toTextSpan)
Expand Down Expand Up @@ -732,7 +732,7 @@ class RubyJsonToNodeCreator(
case x => x
}
val rhs = visit(obj(ParserKeys.Rhs))
lowerAssignmentOperator(lhs, rhs, "||=", obj.toTextSpan)
OperatorAssignment(lhs, "||=", rhs)(obj.toTextSpan)
}

private def visitPair(obj: Obj): RubyExpression = {
Expand Down

0 comments on commit 80145a8

Please sign in to comment.