Skip to content

Commit

Permalink
[CALCITE-6451] Refine nullability of outputs for MINUS and INTERSECT
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarua committed Jul 5, 2024
1 parent 8a96095 commit 3259727
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 0 deletions.
30 changes: 30 additions & 0 deletions core/src/main/java/org/apache/calcite/rel/core/Intersect.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.sql.SqlKind;

import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

/**
* Relational expression that returns the intersection of the rows of its
Expand Down Expand Up @@ -79,4 +83,30 @@ protected Intersect(RelInput input) {
dRows *= 0.25;
return dRows;
}

/**
 * Derives the row type of this intersection.
 *
 * <p>Starts from the least restrictive row type of the inputs and then
 * tightens nullability: a field of the result is nullable only if the least
 * restrictive type says so AND the field at the same ordinal is nullable in
 * every input.
 *
 * @return row type whose fields come from the least restrictive row type,
 *     with the refined nullability applied
 */
@Override protected RelDataType deriveRowType() {
  final List<RelDataTypeField> candidateFields =
      deriveLeastRestrictiveRowType().getFieldList();

  // Field lists of every input, in input order.
  final List<List<RelDataTypeField>> fieldsPerInput =
      getInputs().stream()
          .map(RelNode::getRowType)
          .map(RelDataType::getFieldList)
          .collect(Collectors.toList());

  final RelDataTypeFactory.Builder rowTypeBuilder =
      new RelDataTypeFactory.Builder(getCluster().getTypeFactory());
  int ordinal = 0;
  for (RelDataTypeField candidate : candidateFields) {
    // The least restrictive type may mark a column nullable although it does
    // not need to be; a single input with a NOT NULL column is enough to
    // make the output column NOT NULL.
    boolean nullableEverywhere = candidate.getType().isNullable();
    for (List<RelDataTypeField> inputFields : fieldsPerInput) {
      if (!inputFields.get(ordinal).getType().isNullable()) {
        nullableEverywhere = false;
      }
    }
    rowTypeBuilder.add(candidate).nullable(nullableEverywhere);
    ordinal++;
  }

  return rowTypeBuilder.build();
}
}
20 changes: 20 additions & 0 deletions core/src/main/java/org/apache/calcite/rel/core/Minus.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.metadata.RelMdUtil;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.sql.SqlKind;

import java.util.Collections;
Expand Down Expand Up @@ -60,4 +63,21 @@ protected Minus(RelInput input) {
@Override public double estimateRowCount(RelMetadataQuery mq) {
return RelMdUtil.getMinusRowCount(mq, this);
}

/**
 * Derives the row type of this set difference.
 *
 * <p>Fields are taken from the least restrictive row type of the inputs, but
 * each field's nullability is replaced by the nullability of the field at the
 * same ordinal in the primary input (input 0).
 *
 * @return row type with the nullability of the primary input
 */
@Override protected RelDataType deriveRowType() {
  final List<RelDataTypeField> candidateFields =
      deriveLeastRestrictiveRowType().getFieldList();

  // The least restrictive type may mark a column nullable although it does
  // not need to be; the primary input alone decides the output nullability.
  final List<RelDataTypeField> primaryFields =
      getInput(0).getRowType().getFieldList();

  final RelDataTypeFactory.Builder rowTypeBuilder =
      new RelDataTypeFactory.Builder(getCluster().getTypeFactory());
  final int fieldCount = primaryFields.size();
  for (int ordinal = 0; ordinal < fieldCount; ordinal++) {
    final boolean nullableInPrimary =
        primaryFields.get(ordinal).getType().isNullable();
    rowTypeBuilder.add(candidateFields.get(ordinal)).nullable(nullableInPrimary);
  }

  return rowTypeBuilder.build();
}
}
4 changes: 4 additions & 0 deletions core/src/main/java/org/apache/calcite/rel/core/SetOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ public abstract SetOp copy(
}

// Default row type of a set operation: whatever deriveLeastRestrictiveRowType()
// returns. Kept as a thin delegation so that overriding implementations can
// call deriveLeastRestrictiveRowType() themselves and refine its result.
@Override protected RelDataType deriveRowType() {
return deriveLeastRestrictiveRowType();
}

protected RelDataType deriveLeastRestrictiveRowType() {
final List<RelDataType> inputRowTypes =
Util.transform(inputs, RelNode::getRowType);
final RelDataType rowType =
Expand Down
176 changes: 176 additions & 0 deletions core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
import java.util.stream.Collectors;

import static org.apache.calcite.test.Matchers.hasHints;
import static org.apache.calcite.test.Matchers.hasRelDataType;
import static org.apache.calcite.test.Matchers.hasTree;

import static org.hamcrest.CoreMatchers.allOf;
Expand Down Expand Up @@ -2317,6 +2318,69 @@ private static RelNode groupIdRel(RelBuilder builder, boolean extra) {
assertThat(root, hasTree(expected));
}

/** Checks the row type derived for a three-way UNION (both DISTINCT and ALL):
 * in the expected type, a column is nullable when at least one of the inputs
 * declares it nullable. */
@ParameterizedTest
@ValueSource(booleans = {false, true})
void testUnionTypeDerivation(boolean all) {
  final RelBuilder relBuilder = RelBuilder.create(config().build());
  final RelDataTypeFactory typeFactory = relBuilder.getTypeFactory();

  // Nullability of (a, b, c, d): no, no, yes, yes.
  final RelDataType firstRowType =
      new RelDataTypeFactory.Builder(typeFactory)
          .add("a", SqlTypeName.BIGINT).nullable(false)
          .add("b", SqlTypeName.BIGINT).nullable(false)
          .add("c", SqlTypeName.BIGINT).nullable(true)
          .add("d", SqlTypeName.BIGINT).nullable(true)
          .build();

  // Nullability of (a, b, c, d): no, no, no, no.
  final RelDataType secondRowType =
      new RelDataTypeFactory.Builder(typeFactory)
          .add("a", SqlTypeName.BIGINT).nullable(false)
          .add("b", SqlTypeName.BIGINT).nullable(false)
          .add("c", SqlTypeName.BIGINT).nullable(false)
          .add("d", SqlTypeName.BIGINT).nullable(false)
          .build();

  // Nullability of (a, b, c, d): no, yes, no, yes.
  final RelDataType thirdRowType =
      new RelDataTypeFactory.Builder(typeFactory)
          .add("a", SqlTypeName.BIGINT).nullable(false)
          .add("b", SqlTypeName.BIGINT).nullable(true)
          .add("c", SqlTypeName.BIGINT).nullable(false)
          .add("d", SqlTypeName.BIGINT).nullable(true)
          .build();

  final RelNode root =
      relBuilder
          .values(firstRowType)
          .values(secondRowType)
          .values(thirdRowType)
          .union(all, 3)
          .build();

  // Only "a" is NOT NULL in every input.
  final RelDataType expectedRowType =
      new RelDataTypeFactory.Builder(typeFactory)
          .add("a", SqlTypeName.BIGINT).nullable(false)
          .add("b", SqlTypeName.BIGINT).nullable(true)
          .add("c", SqlTypeName.BIGINT).nullable(true)
          .add("d", SqlTypeName.BIGINT).nullable(true)
          .build();
  final RelDataType actualRowType = root.getRowType();
  assertThat(actualRowType, hasRelDataType(expectedRowType));
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-1522">[CALCITE-1522]
* Fix error message for SetOp with incompatible args</a>. */
Expand Down Expand Up @@ -2521,6 +2585,69 @@ private static RelNode groupIdRel(RelBuilder builder, boolean extra) {
assertThat(root, hasTree(expected));
}

/** Checks the row type derived for a three-way INTERSECT (both DISTINCT and
 * ALL): in the expected type, a column is nullable only when every input
 * declares it nullable. */
@ParameterizedTest
@ValueSource(booleans = {false, true})
void testIntersectTypeDerivation(boolean all) {
  final RelBuilder relBuilder = RelBuilder.create(config().build());
  final RelDataTypeFactory typeFactory = relBuilder.getTypeFactory();

  // Nullability of (a, b, c, d): no, no, yes, yes.
  final RelDataType firstRowType =
      new RelDataTypeFactory.Builder(typeFactory)
          .add("a", SqlTypeName.BIGINT).nullable(false)
          .add("b", SqlTypeName.BIGINT).nullable(false)
          .add("c", SqlTypeName.BIGINT).nullable(true)
          .add("d", SqlTypeName.BIGINT).nullable(true)
          .build();

  // Nullability of (a, b, c, d): yes, yes, yes, yes.
  final RelDataType secondRowType =
      new RelDataTypeFactory.Builder(typeFactory)
          .add("a", SqlTypeName.BIGINT).nullable(true)
          .add("b", SqlTypeName.BIGINT).nullable(true)
          .add("c", SqlTypeName.BIGINT).nullable(true)
          .add("d", SqlTypeName.BIGINT).nullable(true)
          .build();

  // Nullability of (a, b, c, d): no, yes, no, yes.
  final RelDataType thirdRowType =
      new RelDataTypeFactory.Builder(typeFactory)
          .add("a", SqlTypeName.BIGINT).nullable(false)
          .add("b", SqlTypeName.BIGINT).nullable(true)
          .add("c", SqlTypeName.BIGINT).nullable(false)
          .add("d", SqlTypeName.BIGINT).nullable(true)
          .build();

  final RelNode root =
      relBuilder
          .values(firstRowType)
          .values(secondRowType)
          .values(thirdRowType)
          .intersect(all, 3)
          .build();

  // Only "d" is nullable in every input.
  final RelDataType expectedRowType =
      new RelDataTypeFactory.Builder(typeFactory)
          .add("a", SqlTypeName.BIGINT).nullable(false)
          .add("b", SqlTypeName.BIGINT).nullable(false)
          .add("c", SqlTypeName.BIGINT).nullable(false)
          .add("d", SqlTypeName.BIGINT).nullable(true)
          .build();
  final RelDataType actualRowType = root.getRowType();
  assertThat(actualRowType, hasRelDataType(expectedRowType));
}

@Test void testExcept() {
// Equivalent SQL:
// SELECT empno FROM emp
Expand Down Expand Up @@ -2548,6 +2675,55 @@ private static RelNode groupIdRel(RelBuilder builder, boolean extra) {
assertThat(root, hasTree(expected));
}

/** Checks the row type derived for EXCEPT (both DISTINCT and ALL): the
 * expected type has the same nullability as the primary (first) input,
 * regardless of the nullability declared by the secondary input. */
@ParameterizedTest
@ValueSource(booleans = {false, true})
void testExceptTypeDerivation(boolean all) {
  final RelBuilder relBuilder = RelBuilder.create(config().build());
  final RelDataTypeFactory typeFactory = relBuilder.getTypeFactory();

  // Nullability of (a, b, c, d): no, no, yes, yes.
  final RelDataType leftRowType =
      new RelDataTypeFactory.Builder(typeFactory)
          .add("a", SqlTypeName.BIGINT).nullable(false)
          .add("b", SqlTypeName.BIGINT).nullable(false)
          .add("c", SqlTypeName.BIGINT).nullable(true)
          .add("d", SqlTypeName.BIGINT).nullable(true)
          .build();

  // Nullability of (a, b, c, d): no, yes, no, yes.
  final RelDataType rightRowType =
      new RelDataTypeFactory.Builder(typeFactory)
          .add("a", SqlTypeName.BIGINT).nullable(false)
          .add("b", SqlTypeName.BIGINT).nullable(true)
          .add("c", SqlTypeName.BIGINT).nullable(false)
          .add("d", SqlTypeName.BIGINT).nullable(true)
          .build();

  final RelNode root =
      relBuilder
          .values(leftRowType)
          .values(rightRowType)
          .minus(all)
          .build();

  // Same nullability as the primary input.
  final RelDataType expectedRowType =
      new RelDataTypeFactory.Builder(typeFactory)
          .add("a", SqlTypeName.BIGINT).nullable(false)
          .add("b", SqlTypeName.BIGINT).nullable(false)
          .add("c", SqlTypeName.BIGINT).nullable(true)
          .add("d", SqlTypeName.BIGINT).nullable(true)
          .build();
  final RelDataType actualRowType = root.getRowType();
  assertThat(actualRowType, hasRelDataType(expectedRowType));
}

/** Tests building a simple join. Also checks {@link RelBuilder#size()}
* at every step. */
@Test void testJoin() {
Expand Down
13 changes: 13 additions & 0 deletions testkit/src/main/java/org/apache/calcite/test/Matchers.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelValidityChecker;
import org.apache.calcite.rel.hint.Hintable;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.util.TestUtil;
import org.apache.calcite.util.Util;
Expand All @@ -34,6 +35,7 @@
import org.hamcrest.Matcher;
import org.hamcrest.TypeSafeMatcher;
import org.hamcrest.core.Is;
import org.hamcrest.core.IsEqual;
import org.hamcrest.core.StringContains;

import java.nio.charset.Charset;
Expand Down Expand Up @@ -245,6 +247,17 @@ public static Matcher<RelNode> hasFieldNames(String fieldNames) {
}
};
}

/**
 * Creates a Matcher that accepts a {@link RelDataType} whose
 * {@link RelDataType#getFullTypeString()} is equal to the full type string
 * of the given {@code relDataType}.
 *
 * @param relDataType type supplying the expected full type string
 * @return matcher comparing full type strings for equality
 */
public static Matcher<RelDataType> hasRelDataType(RelDataType relDataType) {
  // Computed once, when the matcher is created.
  final String expectedTypeString = relDataType.getFullTypeString();
  return compose(IsEqual.equalTo(expectedTypeString),
      RelDataType::getFullTypeString);
}

/**
* Creates a Matcher that matches a {@link RelNode} if its string
* representation, after converting Windows-style line endings ("\r\n")
Expand Down

0 comments on commit 3259727

Please sign in to comment.