Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] New tests for Timestamp_NTZ writes #4208

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,25 @@ public class Protocol {
/// Public static variables and methods ///
/////////////////////////////////////////////////////////////////////////////////////////////////

/**
* Helper method to get the Protocol from the row representation.
*
* @param row Row representation of the Protocol.
* @return the Protocol object
*/
public static Protocol fromRow(Row row) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: use requireNonNull to assert that row isn't null before we access row.issNullAt etc

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

requireNonNull(row);
Set<String> readerFeatures =
row.isNullAt(2)
? Collections.emptySet()
: Collections.unmodifiableSet(new HashSet<>(VectorUtils.toJavaList(row.getArray(2))));
Set<String> writerFeatures =
row.isNullAt(3)
? Collections.emptySet()
: Collections.unmodifiableSet(new HashSet<>(VectorUtils.toJavaList(row.getArray(3))));
return new Protocol(row.getInt(0), row.getInt(1), readerFeatures, writerFeatures);
}

public static Protocol fromColumnVector(ColumnVector vector, int rowId) {
if (vector.isNullAt(rowId)) {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ package io.delta.kernel.internal.actions

import scala.collection.JavaConverters._

import io.delta.kernel.internal.data.GenericRow
import io.delta.kernel.internal.tablefeatures.TableFeatures
import io.delta.kernel.internal.util.VectorUtils
import io.delta.kernel.types.{ArrayType, IntegerType, StringType, StructType}

import org.scalatest.funsuite.AnyFunSuite

Expand Down Expand Up @@ -369,4 +372,21 @@ class ProtocolSuite extends AnyFunSuite {
assert(merged.getWriterFeatures.asScala === expWriterFeatures)
}
})

test("extract protocol from the row representation") {
val ordinalToValue: Map[Integer, Object] = Map(
Integer.valueOf(0) -> Integer.valueOf(42),
Integer.valueOf(1) -> Integer.valueOf(43),
Integer.valueOf(2) -> VectorUtils.stringArrayValue(List("foo").asJava).asInstanceOf[Object],
Integer.valueOf(3) -> VectorUtils.stringArrayValue(List("bar").asJava).asInstanceOf[Object])
val row = new GenericRow(
new StructType().add("minReaderVersion", IntegerType.INTEGER)
.add("minWriterVersion", IntegerType.INTEGER)
.add("readerFeatures", new ArrayType(StringType.STRING, true))
.add("writerFeatures", new ArrayType(StringType.STRING, true)),
ordinalToValue.asJava)

val expected = new Protocol(42, 43, Set("foo").asJava, Set("bar").asJava)
assert(Protocol.fromRow(row) === expected)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,20 @@
*/
package io.delta.kernel.defaults

import java.util.Collections

import scala.collection.immutable.Seq
import scala.jdk.CollectionConverters._

import io.delta.kernel.engine.Engine
import io.delta.kernel.Operation.CREATE_TABLE
import io.delta.kernel.Table
import io.delta.kernel.expressions.Literal
import io.delta.kernel.internal.actions.{Protocol => KernelProtocol}
import io.delta.kernel.types.{StructType, TimestampNTZType}
import io.delta.kernel.types.IntegerType.INTEGER
import io.delta.kernel.utils.CloseableIterable.emptyIterable

import org.apache.spark.sql.delta.DeltaLog
import org.apache.spark.sql.delta.{DeltaLog, DeltaTableFeatureException}
import org.apache.spark.sql.delta.actions.Protocol

/**
Expand Down Expand Up @@ -132,6 +140,61 @@ class DeltaTableFeaturesSuite extends DeltaTableWriteSuiteBase {
}
}

// Test format: isTimestampNtzEnabled, expected protocol.
Seq(
(true, new KernelProtocol(3, 7, Set("timestampNtz").asJava, Set("timestampNtz").asJava)),
(false, new KernelProtocol(1, 2, Collections.emptySet(), Collections.emptySet())))
.foreach({
case (isTimestampNtzEnabled, expectedProtocol) =>
test(s"Create table with timestampNtz enabled: $isTimestampNtzEnabled") {
withTempDirAndEngine { (tablePath, engine) =>
val table = Table.forPath(engine, tablePath)
val txnBuilder = table.createTransactionBuilder(engine, testEngineInfo, CREATE_TABLE)

val schema = if (isTimestampNtzEnabled) {
new StructType().add("tz", TimestampNTZType.TIMESTAMP_NTZ)
} else {
new StructType().add("id", INTEGER)
}
val txn = txnBuilder
.withSchema(engine, schema)
.build(engine)

assert(txn.getSchema(engine) === schema)
assert(txn.getPartitionColumns(engine).isEmpty)
val txnResult = commitTransaction(txn, engine, emptyIterable())

assert(txnResult.getVersion === 0)
val protocolRow = getProtocolActionFromCommit(engine, table, 0)
assert(protocolRow.isDefined)
val protocol = KernelProtocol.fromRow(protocolRow.get)
assert(protocol.getMinReaderVersion === expectedProtocol.getMinReaderVersion)
assert(protocol.getMinWriterVersion === expectedProtocol.getMinWriterVersion)
assert(protocol.getReaderFeatures.containsAll(expectedProtocol.getReaderFeatures))
assert(protocol.getWriterFeatures.containsAll(expectedProtocol.getWriterFeatures))
}
}
})

test("schema evolution from Spark to add TIMESTAMP_NTZ type on a table created with kernel") {
withTempDirAndEngine { (tablePath, engine) =>
val table = Table.forPath(engine, tablePath)
val txnBuilder = table.createTransactionBuilder(engine, testEngineInfo, CREATE_TABLE)
val txn = txnBuilder
.withSchema(engine, testSchema)
.build(engine)
val txnResult = commitTransaction(txn, engine, emptyIterable())

assert(txnResult.getVersion === 0)
assertThrows[DeltaTableFeatureException] {
spark.sql("ALTER TABLE delta.`" + tablePath + "` ADD COLUMN newCol TIMESTAMP_NTZ")
}
spark.sql("ALTER TABLE delta.`" + tablePath +
"` SET TBLPROPERTIES ('delta.feature.timestampNtz' = 'supported')")
spark.sql("ALTER TABLE delta.`" + tablePath + "` ADD COLUMN newCol TIMESTAMP_NTZ")
}
}

///////////////////////////////////////////////////////////////////////////
// Helper methods
///////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,16 @@ import io.delta.golden.GoldenTableUtils.goldenTablePath
import io.delta.kernel.{Meta, Operation, Table, Transaction, TransactionBuilder, TransactionCommitResult}
import io.delta.kernel.Operation.CREATE_TABLE
import io.delta.kernel.data.{ColumnarBatch, ColumnVector, FilteredColumnarBatch, Row}
import io.delta.kernel.defaults.engine.DefaultEngine
import io.delta.kernel.defaults.internal.data.DefaultColumnarBatch
import io.delta.kernel.defaults.utils.{TestRow, TestUtils}
import io.delta.kernel.engine.Engine
import io.delta.kernel.expressions.Literal
import io.delta.kernel.expressions.Literal.ofInt
import io.delta.kernel.hook.PostCommitHook.PostCommitHookType
import io.delta.kernel.internal.{SnapshotImpl, TableConfig, TableImpl}
import io.delta.kernel.internal.actions.{Metadata, Protocol, SingleAction}
import io.delta.kernel.internal.actions.SingleAction
import io.delta.kernel.internal.fs.{Path => DeltaPath}
import io.delta.kernel.internal.util.Clock
import io.delta.kernel.internal.util.FileNames
import io.delta.kernel.internal.util.{Clock, FileNames, VectorUtils}
import io.delta.kernel.internal.util.SchemaUtils.casePreservingPartitionColNames
import io.delta.kernel.internal.util.Utils.singletonCloseableIterator
import io.delta.kernel.internal.util.Utils.toCloseableIterator
Expand Down
Loading