Skip to content

Commit

Permalink
[ruby] In Pattern Variable Scoping Fix (#5208)
Browse files Browse the repository at this point in the history
Fixed an issue where the variables that the pattern extracted to were interpreted as fields instead of local variables. Also makes sure the pattern match call happens on the original expression and not on the LHS match variable.
  • Loading branch information
DavidBakerEffendi authored Jan 7, 2025
1 parent 2842a88 commit 937c2ff
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -299,25 +299,22 @@ trait AstForControlStructuresCreator(implicit withSchemaValidation: ValidationMo
case x: ArrayPattern =>
val condition = expr.map(e => BinaryExpression(x, "===", e)(x.span)).getOrElse(inClause.pattern)
val body = inClause.body
val variables = x.children.collect { case x: MatchVariable => x }

val variables = x.children.collect { case x: MatchVariable =>
x
}

val conditionBody = if (variables.nonEmpty) {
StatementList(variables.map { x =>
val lhs = SimpleIdentifier()(x.span)
SingleAssignment(lhs, "=", x)(
val conditionBody = if (variables.nonEmpty && expr.isDefined) {
StatementList(variables.map { lhs =>
SingleAssignment(lhs, "=", MatchVariable()(expr.get.span))(
inClause.span
.spanStart(s"${lhs.span.text} = ${RubyOperators.arrayPatternMatch}(${lhs.span.text})")
.spanStart(s"${lhs.span.text} = ${RubyOperators.arrayPatternMatch}(${expr.get.text})")
)
} :+ body)(body.span)
} else {
body
}

(condition, conditionBody)
case x => (x, inClause.body)
case x =>
(x, inClause.body)
}

val conditional = IfExpression(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,11 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
val asts = astsForStatement(x.multipleAssignment)
val call = callNode(node, code(node), op, op, DispatchTypes.STATIC_DISPATCH)
return callAst(call, asts :+ rhsAst)
case x: MatchVariable =>
handleVariableOccurrence(x.toSimpleIdentifier) // Create local variable under this scope
val matchIden = astForExpression(x.toSimpleIdentifier)
val call = callNode(node, code(node), op, op, DispatchTypes.STATIC_DISPATCH)
return callAst(call, matchIden :: rhsAst :: Nil)
case _ => astForExpression(node.lhs)
}

Expand Down Expand Up @@ -618,7 +623,10 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
protected def astForArrayPattern(node: ArrayPattern): Ast = {
val callNode_ =
callNode(node, code(node), Operators.arrayInitializer, Operators.arrayInitializer, DispatchTypes.STATIC_DISPATCH)
val childrenAst = node.children.map(astForExpression)
val childrenAst = node.children.map {
case x: MatchVariable => astForExpression(SimpleIdentifier()(x.span))
case x => astForExpression(x)
}

callAst(callNode_, childrenAst)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,9 @@ object RubyIntermediateAst {

final case class ArrayPattern(children: List[RubyExpression])(span: TextSpan) extends RubyExpression(span)

final case class MatchVariable()(span: TextSpan) extends RubyExpression(span)
final case class MatchVariable()(span: TextSpan) extends RubyExpression(span) {
def toSimpleIdentifier: SimpleIdentifier = SimpleIdentifier()(span)
}

final case class NextExpression()(span: TextSpan) extends RubyExpression(span) with ControlFlowStatement

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ class CaseTests extends RubyCode2CpgFixture {
lhs.name shouldBe "result"

rhs.methodFullName shouldBe RubyOperators.arrayPatternMatch
rhs.code shouldBe s"${RubyOperators.arrayPatternMatch}(result)"
rhs.code shouldBe s"${RubyOperators.arrayPatternMatch}(<tmp-0>)"
case xs => fail(s"Expected lhs and rhs, got [${xs.code.mkString(",")}]")
}

Expand All @@ -211,7 +211,7 @@ class CaseTests extends RubyCode2CpgFixture {
lhs.name shouldBe "notResult"

rhs.methodFullName shouldBe RubyOperators.arrayPatternMatch
rhs.code shouldBe s"${RubyOperators.arrayPatternMatch}(notResult)"
rhs.code shouldBe s"${RubyOperators.arrayPatternMatch}(<tmp-0>)"
case xs => fail(s"Expected lhs and rhs, got [${xs.code.mkString(",")}]")
}
case _ => fail(s"Expected two true branches")
Expand Down

0 comments on commit 937c2ff

Please sign in to comment.