Skip to content

Commit

Permalink
improve logic to work with exception handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
dusantism-db committed Feb 7, 2025
1 parent 7d3008e commit 901aa6c
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class ResolveCatalogs(val catalogManager: CatalogManager)

private def resolveCreateVariableName(nameParts: Seq[String]): ResolvedIdentifier = {
val ident = SqlScriptingLocalVariableManager.get()
.filterNot(AnalysisContext.get.isExecuteImmediate)
.filterNot(_ => AnalysisContext.get.isExecuteImmediate)
.getOrElse(catalogManager.tempVariableManager)
.createIdentifier(nameParts.last)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.catalog.{VariableDefinition, VariableManage
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.errors.DataTypeErrorsBase
import org.apache.spark.sql.errors.QueryCompilationErrors.unresolvedVariableError
import org.apache.spark.sql.errors.QueryCompilationErrors.unresolvedVariableError

class SqlScriptingLocalVariableManager(context: SqlScriptingExecutionContext)
extends VariableManager with DataTypeErrorsBase {
Expand Down Expand Up @@ -60,42 +60,52 @@ class SqlScriptingLocalVariableManager(context: SqlScriptingExecutionContext)
initValue: Literal,
identifier: Identifier): Unit = {
def varDef = VariableDefinition(identifier, defaultValueSQL, initValue)
nameParts match {
// Unqualified case.
case Seq(name) =>
context.currentFrame.scopes
.findLast(_.variables.contains(name))
// Throw error if variable is not found. This shouldn't happen as the check is already
// done in SetVariableExec.
.orElse(throw unresolvedVariableError(nameParts, identifier.namespace().toIndexedSeq))
.map(_.variables.put(name, varDef))
findScopeOfVariable(nameParts)
.getOrElse(throw unresolvedVariableError(nameParts, identifier.namespace().toIndexedSeq))
.variables.put(nameParts.last, varDef)
}

override def get(nameParts: Seq[String]): Option[VariableDefinition] = {
findScopeOfVariable(nameParts).flatMap(_.variables.get(nameParts.last))
}

private def findScopeOfVariable(nameParts: Seq[String]): Option[SqlScriptingExecutionScope] = {
def isScopeOfVar(
nameParts: Seq[String],
scope: SqlScriptingExecutionScope
): Boolean = nameParts match {
case Seq(name) => scope.variables.contains(name)
// Qualified case.
case Seq(label, name) =>
context.currentFrame.scopes
.findLast(_.label == label)
.filter(_.variables.contains(name))
// Throw error if variable is not found. This shouldn't happen as the check is already
// done in SetVariableExec.
.orElse(throw unresolvedVariableError(nameParts, identifier.namespace().toIndexedSeq))
.map(_.variables.put(name, varDef))
case Seq(label, _) => scope.label == label
case _ =>
throw SparkException.internalError("ScriptingVariableManager.set expects 1 or 2 nameParts.")
throw SparkException.internalError("ScriptingVariableManager expects 1 or 2 nameParts.")
}
}

override def get(nameParts: Seq[String]): Option[VariableDefinition] = nameParts match {
// Unqualified case.
case Seq(name) =>
context.currentFrame.scopes
.findLast(_.variables.contains(name))
.flatMap(_.variables.get(name))
// Qualified case.
case Seq(label, name) =>
context.currentFrame.scopes
.findLast(_.label == label)
.flatMap(_.variables.get(name))
case _ =>
throw SparkException.internalError("ScriptingVariableManager.get expects 1 or 2 nameParts.")
// First search for variable in entire current frame.
val resCurrentFrame = context.currentFrame.scopes
.findLast(scope => isScopeOfVar(nameParts, scope))
if (resCurrentFrame.isDefined) {
return resCurrentFrame
}

// When searching in previous frames, for each frame we have to check only scopes before and
// including the scope where the previously checked frame is defined, as the frames
// should not access variables from scopes which are nested below it's definition.
var previousFrameDefinitionLabel = context.currentFrame.scopeLabel

context.frames.dropRight(1).reverseIterator.foreach(frame => {
val candidateScopes = frame.scopes.reverse.dropWhile(
scope => !previousFrameDefinitionLabel.contains(scope.label))

val scope = candidateScopes.findLast(scope => isScopeOfVar(nameParts, scope))
if (scope.isDefined) {
return scope
}
if (candidateScopes.nonEmpty) {
previousFrameDefinitionLabel = frame.scopeLabel
}
})
None
}

override def createIdentifier(name: String): Identifier =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2381,22 +2381,25 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession {
}

test("local variable - execute immediate create session var") {
val sqlScript =
"""
|BEGIN
| EXECUTE IMMEDIATE 'DECLARE sessionVar = 5';
| SELECT system.session.sessionVar;
| SELECT sessionVar;
|END
|""".stripMargin
val expected = Seq(
Seq(Row(5)), // select system.session.sessionVar
Seq(Row(5)) // select sessionVar
)
verifySqlScriptResult(sqlScript, expected)
withSessionVariable("sessionVar") {
val sqlScript =
"""
|BEGIN
| EXECUTE IMMEDIATE 'DECLARE sessionVar = 5';
| SELECT system.session.sessionVar;
| SELECT sessionVar;
|END
|""".stripMargin
val expected = Seq(
Seq(Row(5)), // select system.session.sessionVar
Seq(Row(5)) // select sessionVar
)
verifySqlScriptResult(sqlScript, expected)
}
}

test("local variable - execute immediate create qualified session var") {
withSessionVariable("sessionVar") {
val sqlScript =
"""
|BEGIN
Expand All @@ -2410,6 +2413,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession {
Seq(Row(5)) // select sessionVar
)
verifySqlScriptResult(sqlScript, expected)
}
}

test("local variable - execute immediate set session var") {
Expand All @@ -2435,4 +2439,59 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession {
verifySqlScriptResult(sqlScript, expected)
}
}

test("local variable - handlers - triple chained handlers") {
val sqlScript =
"""
|BEGIN
| DECLARE OR REPLACE VARIABLE varOuter INT = 0;
| l1: BEGIN
| DECLARE OR REPLACE VARIABLE varL1 INT = 1;
| DECLARE EXIT HANDLER FOR SQLEXCEPTION
| BEGIN
| SELECT varOuter;
| SELECT varL1;
| END;
| l2: BEGIN
| DECLARE OR REPLACE VARIABLE varL2 = 2;
| DECLARE EXIT HANDLER FOR SQLEXCEPTION
| BEGIN
| SELECT varOuter;
| SELECT varL1;
| SELECT varL2;
| SELECT 1/0;
| END;
| l3: BEGIN
| DECLARE OR REPLACE VARIABLE varL3 = 3;
| DECLARE EXIT HANDLER FOR SQLEXCEPTION
| BEGIN
| SELECT varOuter;
| SELECT varL1;
| SELECT varL2;
| SELECT varL3;
| SELECT 1/0;
| END;
| SELECT 5;
| SELECT 1/0;
| SELECT 6;
| END;
| END;
| END;
|END
|""".stripMargin
val expected = Seq(
Seq(Row(5)),
Seq(Row(0)),
Seq(Row(1)),
Seq(Row(2)),
Seq(Row(3)),
Seq(Row(0)),
Seq(Row(1)),
Seq(Row(2)),
Seq(Row(0)),
Seq(Row(1))
)
verifySqlScriptResult(sqlScript, expected = expected)
}
}

0 comments on commit 901aa6c

Please sign in to comment.