From ee619fb3ccfca3fe729519d1e0a99e9c2da02845 Mon Sep 17 00:00:00 2001
From: osopardo1
Date: Thu, 21 Sep 2023 16:35:49 +0200
Subject: [PATCH] Fix AnsiCast and update README

---
 README.md                                    |  5 +-
 .../internal/rules/QbeastAnalysisUtils.scala | 63 ++++++++++++++++---
 2 files changed, 56 insertions(+), 12 deletions(-)

diff --git a/README.md b/README.md
index 39be82260..b47fae715 100644
--- a/README.md
+++ b/README.md
@@ -18,9 +18,9 @@
 **Qbeast Spark** is an extension for [**Data Lakehouses**](http://cidrdb.org/cidr2021/papers/cidr2021_paper17.pdf) that enables **multi-dimensional filtering** and **sampling** directly on the storage
 
-[![apache-spark](https://img.shields.io/badge/apache--spark-3.3.x-blue)](https://spark.apache.org/releases/spark-release-3-2-2.html)
+[![apache-spark](https://img.shields.io/badge/apache--spark-3.4.x-blue)](https://spark.apache.org/releases/spark-release-3-4-1.html)
 [![apache-hadoop](https://img.shields.io/badge/apache--hadoop-3.3.x-blue)](https://hadoop.apache.org/release/3.3.1.html)
-[![delta-core](https://img.shields.io/badge/delta--core-2.1.0-blue)](https://github.com/delta-io/delta/releases/tag/v1.2.0)
+[![delta-core](https://img.shields.io/badge/delta--core-2.4.0-blue)](https://github.com/delta-io/delta/releases/tag/v2.4.0)
 [![codecov](https://codecov.io/gh/Qbeast-io/qbeast-spark/branch/main/graph/badge.svg?token=8WO7HGZ4MW)](https://codecov.io/gh/Qbeast-io/qbeast-spark)
@@ -170,6 +170,7 @@ Use [Python index visualizer](./utils/visualizer/README.md) for your indexed tab
 | 0.2.0 | 3.1.x | 3.2.0 | 1.0.0 |
 | 0.3.x | 3.2.x | 3.3.x | 1.2.x |
 | 0.4.x | 3.3.x | 3.3.x | 2.1.x |
+| 0.5.x | 3.4.x | 3.3.x | 2.4.x |
 
 Check [here](https://docs.delta.io/latest/releases.html) for **Delta Lake** and **Apache Spark** version compatibility.
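For reference, the compatibility row added above corresponds to roughly the following dependency pins. This is a sketch only, not part of the patch: the artifact coordinates and the exact 0.5.0 version number are assumptions read off the table, not confirmed by this diff.

```scala
// build.sbt sketch; versions follow the new compatibility row
// (qbeast-spark 0.5.x / Spark 3.4.x / Delta 2.4.x).
// Treat the coordinates and version numbers as assumptions.
libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-sql" % "3.4.1" % "provided",
  "io.delta" %% "delta-core" % "2.4.0",
  "io.qbeast" %% "qbeast-spark" % "0.5.0")
```
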
diff --git a/src/main/scala/io/qbeast/spark/internal/rules/QbeastAnalysisUtils.scala b/src/main/scala/io/qbeast/spark/internal/rules/QbeastAnalysisUtils.scala
index dbc68f642..0577360ad 100644
--- a/src/main/scala/io/qbeast/spark/internal/rules/QbeastAnalysisUtils.scala
+++ b/src/main/scala/io/qbeast/spark/internal/rules/QbeastAnalysisUtils.scala
@@ -3,24 +3,30 @@
  */
 package io.qbeast.spark.internal.rules
 
+import org.apache.spark.sql.catalyst.analysis.TableOutputResolver
 import org.apache.spark.sql.{AnalysisExceptionFactory, SchemaUtils}
 import org.apache.spark.sql.catalyst.expressions.{
   Alias,
-  AnsiCast,
+  ArrayTransform,
   Attribute,
   Cast,
   CreateStruct,
   Expression,
+  GetArrayItem,
   GetStructField,
+  LambdaFunction,
   NamedExpression,
+  NamedLambdaVariable,
   UpCast
 }
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, StructField, StructType}
 
 private[rules] object QbeastAnalysisUtils {
 
+  private lazy val conf = SQLConf.get
+
   /**
    * Checks if the schema of the Table corresponds to the schema of the Query
    * From Delta Lake OSS Project code in DeltaAnalysis
@@ -72,7 +78,7 @@ private[rules] object QbeastAnalysisUtils {
     Project(project, query)
   }
 
-  type CastFunction = (Expression, DataType) => Expression
+  type CastFunction = (Expression, DataType, String) => Expression
 
   /**
    * From DeltaAnalysis code in spark/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala
@@ -80,16 +86,19 @@ private[rules] object QbeastAnalysisUtils {
    * @return
    */
   def getCastFunction: CastFunction = {
-    val conf = SQLConf.get
     val timeZone = conf.sessionLocalTimeZone
     conf.storeAssignmentPolicy match {
       case SQLConf.StoreAssignmentPolicy.LEGACY =>
-        Cast(_, _, Option(timeZone), ansiEnabled = false)
+        (input: Expression, dt: DataType, _) =>
+          Cast(input, dt, Option(timeZone), ansiEnabled = false)
       case SQLConf.StoreAssignmentPolicy.ANSI =>
-        (input: Expression, dt: DataType) => {
-          AnsiCast(input, dt, Option(timeZone))
+        (input: Expression, dt: DataType, name: String) => {
+          val cast = Cast(input, dt, Option(timeZone), ansiEnabled = true)
+          cast.setTagValue(Cast.BY_TABLE_INSERTION, ())
+          TableOutputResolver.checkCastOverflowInTableInsert(cast, name)
         }
-      case SQLConf.StoreAssignmentPolicy.STRICT => UpCast(_, _)
+      case SQLConf.StoreAssignmentPolicy.STRICT =>
+        (input: Expression, dt: DataType, _) => UpCast(input, dt)
     }
   }
 
@@ -131,7 +140,10 @@ private[rules] object QbeastAnalysisUtils {
       case (other, i) if i < target.length =>
         val targetAttr = target(i)
         Alias(
-          getCastFunction(GetStructField(parent, i, Option(other.name)), targetAttr.dataType),
+          getCastFunction(
+            GetStructField(parent, i, Option(other.name)),
+            targetAttr.dataType,
+            targetAttr.name),
           targetAttr.name)(explicitMetadata = Option(targetAttr.metadata))
 
       case (other, i) =>
@@ -146,6 +158,34 @@ private[rules] object QbeastAnalysisUtils {
       Option(parent.metadata))
   }
 
+  /**
+   * From DeltaAnalysis code in spark/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala
+   *
+   * Recursively add casts to Array[Struct]
+   * @param tableName the name of the table
+   * @param parent the parent expression
+   * @param source the source Struct
+   * @param target the final target Struct
+   * @param sourceNullable if source is nullable
+   * @return
+   */
+
+  private def addCastsToArrayStructs(
+      tableName: String,
+      parent: NamedExpression,
+      source: StructType,
+      target: StructType,
+      sourceNullable: Boolean): Expression = {
+    val structConverter: (Expression, Expression) => Expression = (_, i) =>
+      addCastsToStructs(tableName, Alias(GetArrayItem(parent, i), i.toString)(), source, target)
+    val transformLambdaFunc = {
+      val elementVar = NamedLambdaVariable("elementVar", source, sourceNullable)
+      val indexVar = NamedLambdaVariable("indexVar", IntegerType, false)
+      LambdaFunction(structConverter(elementVar, indexVar), Seq(elementVar, indexVar))
+    }
+    ArrayTransform(parent, transformLambdaFunc)
+  }
+
   /**
    * From DeltaAnalysis code in spark/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala
    * Adds cast to input/query column from the target table
@@ -163,8 +203,11 @@ private[rules] object QbeastAnalysisUtils {
         attr
       case (s: StructType, t: StructType) if s != t =>
         addCastsToStructs(tblName, attr, s, t)
+      case (ArrayType(s: StructType, sNull: Boolean), ArrayType(t: StructType, tNull: Boolean))
+          if s != t && sNull == tNull =>
+        addCastsToArrayStructs(tblName, attr, s, t, sNull)
       case _ =>
-        getCastFunction(attr, targetAttr.dataType)
+        getCastFunction(attr, targetAttr.dataType, targetAttr.name)
     }
     Alias(expr, targetAttr.name)(explicitMetadata = Option(targetAttr.metadata))
   }
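
Context for the `getCastFunction` change: `AnsiCast` no longer exists as a separate expression in Spark 3.4, so the ANSI branch now builds a plain `Cast` with `ansiEnabled = true`, tags it with `Cast.BY_TABLE_INSERTION`, and routes it through `TableOutputResolver.checkCastOverflowInTableInsert` so overflows surface as table-insertion errors. A minimal spark-shell sketch of the resulting behavior, assuming a session with the Delta/Qbeast extensions configured; the table name is hypothetical:

```scala
// Sketch only: illustrates the ANSI store-assignment path above.
// Assumes `spark` is a SparkSession with Delta Lake on the classpath.
spark.conf.set("spark.sql.storeAssignmentPolicy", "ANSI")
spark.sql("CREATE TABLE demo (x TINYINT) USING delta")
// TINYINT holds -128..127; with the tagged Cast this insert should fail
// with [CAST_OVERFLOW_IN_TABLE_INSERT] rather than a generic cast error.
spark.sql("INSERT INTO demo VALUES (128)")
```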
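
On `addCastsToArrayStructs`: `ArrayTransform` over a two-argument `LambdaFunction` is the expression-level counterpart of SQL's `transform(arr, (element, index) -> ...)`, with each struct element rebuilt by `addCastsToStructs`. A rough DataFrame-level analogue using only public APIs; the data and column names are hypothetical:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// An array-of-structs column whose field types need casting to a target schema.
val df = Seq((1, Seq((1, "a"), (2, "b")))).toDF("id", "arr")

// Public-API analogue of ArrayTransform + LambdaFunction(elementVar, indexVar):
// rebuild every element, casting each field to its target type.
val casted = df.withColumn(
  "arr",
  transform($"arr", (elem, _) =>
    struct(elem("_1").cast("long").as("_1"), elem("_2").as("_2"))))

casted.printSchema() // arr's element type becomes struct<_1: bigint, _2: string>
```

The catalyst version in the patch does the same rebuild per element but keeps the index variable available, since `addCastsToStructs` names the intermediate alias after the element's position.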