Fix AnsiCast and update README
osopardo1 authored and committed Sep 21, 2023
1 parent 061eb14 · commit ee619fb
Showing 2 changed files with 56 additions and 12 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -18,9 +18,9 @@

**Qbeast Spark** is an extension for [**Data Lakehouses**](http://cidrdb.org/cidr2021/papers/cidr2021_paper17.pdf) that enables **multi-dimensional filtering** and **sampling** directly on the storage

[![apache-spark](https://img.shields.io/badge/apache--spark-3.3.x-blue)](https://spark.apache.org/releases/spark-release-3-2-2.html)
[![apache-spark](https://img.shields.io/badge/apache--spark-3.4.x-blue)](https://spark.apache.org/releases/spark-release-3-4-1.html)
[![apache-hadoop](https://img.shields.io/badge/apache--hadoop-3.3.x-blue)](https://hadoop.apache.org/release/3.3.1.html)
[![delta-core](https://img.shields.io/badge/delta--core-2.1.0-blue)](https://github.com/delta-io/delta/releases/tag/v1.2.0)
[![delta-core](https://img.shields.io/badge/delta--core-2.4.0-blue)](https://github.com/delta-io/delta/releases/tag/v2.4.0)
[![codecov](https://codecov.io/gh/Qbeast-io/qbeast-spark/branch/main/graph/badge.svg?token=8WO7HGZ4MW)](https://codecov.io/gh/Qbeast-io/qbeast-spark)

</div>
@@ -170,6 +170,7 @@ Use [Python index visualizer](./utils/visualizer/README.md) for your indexed tables
| 0.2.0 | 3.1.x | 3.2.0 | 1.0.0 |
| 0.3.x | 3.2.x | 3.3.x | 1.2.x |
| 0.4.x | 3.3.x | 3.3.x | 2.1.x |
| 0.5.x | 3.4.x | 3.3.x | 2.4.x |

Check [here](https://docs.delta.io/latest/releases.html) for **Delta Lake** and **Apache Spark** version compatibility.
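For example, the 0.5.x line pairs with Spark 3.4.x and delta-core 2.4.0. Below is a minimal sketch of writing an indexed table, assuming the artifacts `io.qbeast:qbeast-spark_2.12:0.5.0` and `io.delta:delta-core_2.12:2.4.0` are on the classpath and the session extension class is `io.qbeast.spark.internal.QbeastSparkSessionExtension`:

```scala
import org.apache.spark.sql.SparkSession

// Assumes io.qbeast:qbeast-spark_2.12:0.5.0 and io.delta:delta-core_2.12:2.4.0
// on the classpath, matching the 0.5.x row of the table above.
val spark = SparkSession
  .builder()
  .master("local[*]")
  .config("spark.sql.extensions", "io.qbeast.spark.internal.QbeastSparkSessionExtension")
  .getOrCreate()

// Write a small DataFrame through the Qbeast data source, indexing two columns.
val df = spark.range(1000).selectExpr("id", "id % 10 AS category")
df.write
  .format("qbeast")
  .option("columnsToIndex", "id,category") // columns that build the multi-dimensional index
  .save("/tmp/qbeast-table")
```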

63 changes: 53 additions & 10 deletions src/main/scala/io/qbeast/spark/internal/rules/QbeastAnalysisUtils.scala
@@ -3,24 +3,30 @@
*/
package io.qbeast.spark.internal.rules

import org.apache.spark.sql.catalyst.analysis.TableOutputResolver
import org.apache.spark.sql.{AnalysisExceptionFactory, SchemaUtils}
import org.apache.spark.sql.catalyst.expressions.{
Alias,
AnsiCast,
ArrayTransform,
Attribute,
Cast,
CreateStruct,
Expression,
GetArrayItem,
GetStructField,
LambdaFunction,
NamedExpression,
NamedLambdaVariable,
UpCast
}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, StructField, StructType}

private[rules] object QbeastAnalysisUtils {

private lazy val conf = SQLConf.get

/**
* Checks if the schema of the Table corresponds to the schema of the Query
* From Delta Lake OSS Project code in DeltaAnalysis
@@ -72,24 +78,27 @@ private[rules] object QbeastAnalysisUtils {
Project(project, query)
}
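// Example with hypothetical schemas (same mechanism as above): if the query
// outputs (id INT, name STRING) but the table expects (id BIGINT, name STRING),
// the returned Project wraps the query and casts/renames each output column so
// the final plan matches the table schema exactly.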

type CastFunction = (Expression, DataType) => Expression
type CastFunction = (Expression, DataType, String) => Expression

/**
* From DeltaAnalysis code in spark/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala
* Returns the cast operation for the level of strictness in schema enforcement that the user asked for
* @return the cast function matching the session's store assignment policy
*/
def getCastFunction: CastFunction = {
val conf = SQLConf.get
val timeZone = conf.sessionLocalTimeZone
conf.storeAssignmentPolicy match {
case SQLConf.StoreAssignmentPolicy.LEGACY =>
Cast(_, _, Option(timeZone), ansiEnabled = false)
(input: Expression, dt: DataType, _) =>
Cast(input, dt, Option(timeZone), ansiEnabled = false)
case SQLConf.StoreAssignmentPolicy.ANSI =>
(input: Expression, dt: DataType) => {
AnsiCast(input, dt, Option(timeZone))
(input: Expression, dt: DataType, name: String) => {
val cast = Cast(input, dt, Option(timeZone), ansiEnabled = true)
cast.setTagValue(Cast.BY_TABLE_INSERTION, ())
TableOutputResolver.checkCastOverflowInTableInsert(cast, name)
}
case SQLConf.StoreAssignmentPolicy.STRICT => UpCast(_, _)
case SQLConf.StoreAssignmentPolicy.STRICT =>
(input: Expression, dt: DataType, _) => UpCast(input, dt)
}
}
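// Sketch with a hypothetical call site: getCastFunction(input, LongType, "amount")
// behaves differently per policy:
//   LEGACY -> Cast(input, LongType, ansiEnabled = false): invalid values turn into
//             nulls or wrap around silently
//   ANSI   -> Cast(..., ansiEnabled = true), tagged with Cast.BY_TABLE_INSERTION and
//             wrapped by TableOutputResolver.checkCastOverflowInTableInsert, so an
//             overflow fails with an error naming the "amount" column
//   STRICT -> UpCast(input, LongType): analysis fails unless the cast is lossless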

@@ -131,7 +140,10 @@ private[rules] object QbeastAnalysisUtils {
case (other, i) if i < target.length =>
val targetAttr = target(i)
Alias(
getCastFunction(GetStructField(parent, i, Option(other.name)), targetAttr.dataType),
getCastFunction(
GetStructField(parent, i, Option(other.name)),
targetAttr.dataType,
targetAttr.name),
targetAttr.name)(explicitMetadata = Option(targetAttr.metadata))

case (other, i) =>
@@ -146,6 +158,34 @@ private[rules] object QbeastAnalysisUtils {
Option(parent.metadata))
}

/**
* From DeltaAnalysis code in spark/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala
*
* Recursively adds casts to an Array[Struct] column
* @param tableName the name of the table
* @param parent the parent expression
* @param source the source Struct
* @param target the final target Struct
* @param sourceNullable if source is nullable
* @return an ArrayTransform that casts every struct element of the array
*/
private def addCastsToArrayStructs(
tableName: String,
parent: NamedExpression,
source: StructType,
target: StructType,
sourceNullable: Boolean): Expression = {
val structConverter: (Expression, Expression) => Expression = (_, i) =>
addCastsToStructs(tableName, Alias(GetArrayItem(parent, i), i.toString)(), source, target)
val transformLambdaFunc = {
val elementVar = NamedLambdaVariable("elementVar", source, sourceNullable)
val indexVar = NamedLambdaVariable("indexVar", IntegerType, false)
LambdaFunction(structConverter(elementVar, indexVar), Seq(elementVar, indexVar))
}
ArrayTransform(parent, transformLambdaFunc)
}
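// Example with hypothetical schemas: writing arr: ARRAY<STRUCT<a: INT>> into a
// target ARRAY<STRUCT<a: BIGINT>> applies the struct conversion above to every
// element through the (elementVar, indexVar) lambda, roughly:
//   transform(arr, (elementVar, indexVar) -> named_struct("a", CAST(arr[indexVar].a AS BIGINT)))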

/**
* From DeltaAnalysis code in spark/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala
* Adds a cast to the input/query column so it matches the corresponding target table column
@@ -163,8 +203,11 @@ private[rules] object QbeastAnalysisUtils {
attr
case (s: StructType, t: StructType) if s != t =>
addCastsToStructs(tblName, attr, s, t)
case (ArrayType(s: StructType, sNull: Boolean), ArrayType(t: StructType, tNull: Boolean))
if s != t && sNull == tNull =>
addCastsToArrayStructs(tblName, attr, s, t, sNull)
case _ =>
getCastFunction(attr, targetAttr.dataType)
getCastFunction(attr, targetAttr.dataType, targetAttr.name)
}
Alias(expr, targetAttr.name)(explicitMetadata = Option(targetAttr.metadata))
}
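// Dispatch sketch for an input attribute attr against its target column targetAttr:
//   matching types                 -> attr passes through unchanged
//   differing STRUCT types         -> addCastsToStructs (field-by-field casts)
//   differing ARRAY<STRUCT> types
//   with equal nullability         -> addCastsToArrayStructs (ArrayTransform + lambda)
//   anything else                  -> getCastFunction(attr, targetAttr.dataType, targetAttr.name)
// The result is aliased to targetAttr.name so the final Project matches the table schema.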
