diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index 49afcd5ebcd50..cb132ab11326d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -98,7 +98,7 @@ public String build(Expression expr) {
         case "ENDS_WITH" -> visitEndsWith(build(e.children()[0]), build(e.children()[1]));
         case "CONTAINS" -> visitContains(build(e.children()[0]), build(e.children()[1]));
         case "=", "<>", "<=>", "<", "<=", ">", ">=" ->
-          visitBinaryComparison(name, inputToSQL(e.children()[0]), inputToSQL(e.children()[1]));
+          visitBinaryComparison(name, e.children()[0], e.children()[1]);
         case "+", "*", "/", "%", "&", "|", "^" ->
           visitBinaryArithmetic(name, inputToSQL(e.children()[0]), inputToSQL(e.children()[1]));
         case "-" -> {
@@ -219,6 +219,10 @@ protected String inputToSQL(Expression input) {
     }
   }
 
+  protected String visitBinaryComparison(String name, Expression le, Expression re) {
+    return visitBinaryComparison(name, inputToSQL(le), inputToSQL(re));
+  }
+
   protected String visitBinaryComparison(String name, String l, String r) {
     if (name.equals("<=>")) {
       return "((" + l + " IS NOT NULL AND " + r + " IS NOT NULL AND " + l + " = " + r + ") " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index adb0da1a21264..114b524dcd96a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -24,7 +24,7 @@ import scala.util.control.NonFatal
 
 import org.apache.spark.{SparkThrowable, SparkUnsupportedOperationException}
 import org.apache.spark.sql.catalyst.SQLConfHelper
-import org.apache.spark.sql.connector.expressions.{Expression, GeneralScalarExpression, Literal}
+import org.apache.spark.sql.connector.expressions.{Expression, Literal}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.jdbc.OracleDialect._
@@ -62,33 +62,27 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N
         super.visitAggregateFunction(funcName, isDistinct, inputs)
       }
 
+    override def visitBinaryComparison(name: String, le: Expression, re: Expression): String = {
+      (le, re) match {
+        case (lhs: Literal[_], rhs: Expression) if lhs.dataType == BinaryType =>
+          compareBlob(lhs, name, rhs)
+        case (lhs: Expression, rhs: Literal[_]) if rhs.dataType == BinaryType =>
+          compareBlob(lhs, name, rhs)
+        case _ =>
+          super.visitBinaryComparison(name, le, re);
+      }
+    }
+
     private def compareBlob(lhs: Expression, operator: String, rhs: Expression): String = {
       val l = inputToSQL(lhs)
       val r = inputToSQL(rhs)
-      val op = if (operator == "<=>") "=" else operator
-      val compare = s"DBMS_LOB.COMPARE($l, $r) $op 0"
       if (operator == "<=>") {
+        val compare = s"DBMS_LOB.COMPARE($l, $r) = 0"
         s"(($l IS NOT NULL AND $r IS NOT NULL AND $compare) OR ($l IS NULL AND $r IS NULL))"
       } else {
-        compare
+        s"DBMS_LOB.COMPARE($l, $r) $operator 0"
       }
     }
-
-    override def build(expr: Expression): String = expr match {
-      case e: GeneralScalarExpression =>
-        e.name() match {
-          case "=" | "<>" | "<=>" | "<" | "<=" | ">" | ">=" =>
-            (e.children()(0), e.children()(1)) match {
-              case (lhs: Literal[_], rhs: Expression) if lhs.dataType == BinaryType =>
-                compareBlob(lhs, e.name, rhs)
-              case (lhs: Expression, rhs: Literal[_]) if rhs.dataType == BinaryType =>
-                compareBlob(lhs, e.name, rhs)
-              case _ => super.build(expr)
-            }
-          case _ => super.build(expr)
-        }
-      case _ => super.build(expr)
-    }
   }
 
   override def compileExpression(expr: Expression): Option[String] = {