Skip to content

Commit

Permalink
[CALCITE-6451] Refine nullability of outputs for MINUS and INTERSECT
Browse files Browse the repository at this point in the history
For Minus, the output column nullability is that of the primary input

For Intersect, an output column is nullable if and only if it is
nullable in all of the inputs
  • Loading branch information
vbarua committed Jul 6, 2024
1 parent 8a96095 commit 6238792
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 0 deletions.
31 changes: 31 additions & 0 deletions core/src/main/java/org/apache/calcite/rel/core/Intersect.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.sql.SqlKind;

import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

/**
* Relational expression that returns the intersection of the rows of its
Expand Down Expand Up @@ -79,4 +83,31 @@ protected Intersect(RelInput input) {
dRows *= 0.25;
return dRows;
}

@Override protected RelDataType deriveRowType() {
final RelDataType leastRestrictiveRowType = deriveLeastRestrictiveRowType();
List<RelDataTypeField> outputFields = leastRestrictiveRowType.getFieldList();

List<List<RelDataTypeField>> inputs = getInputs()
.stream()
.map(i -> i.getRowType().getFieldList())
.collect(Collectors.toList());

// The leastRestrictiveRowType can potentially have columns marked as nullable that do not need
// to be so. An output column is only nullable if it is nullable in ALL the inputs.
final RelDataTypeFactory.Builder typeBuilder =
new RelDataTypeFactory.Builder(getCluster().getTypeFactory());
for (int fieldIndex = 0; fieldIndex < outputFields.size(); fieldIndex++) {
boolean isNullable = outputFields.get(fieldIndex).getType().isNullable();
for (List<RelDataTypeField> input : inputs) {
isNullable &= input.get(fieldIndex).getType().isNullable();
}

typeBuilder
.add(outputFields.get(fieldIndex))
.nullable(isNullable);
}

return typeBuilder.build();
}
}
21 changes: 21 additions & 0 deletions core/src/main/java/org/apache/calcite/rel/core/Minus.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.metadata.RelMdUtil;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.sql.SqlKind;

import java.util.Collections;
Expand Down Expand Up @@ -60,4 +63,22 @@ protected Minus(RelInput input) {
@Override public double estimateRowCount(RelMetadataQuery mq) {
return RelMdUtil.getMinusRowCount(mq, this);
}

@Override protected RelDataType deriveRowType() {
final RelDataType leastRestrictiveRowType = deriveLeastRestrictiveRowType();
List<RelDataTypeField> outputFields = leastRestrictiveRowType.getFieldList();

// The leastRestrictiveRowType can potentially have columns marked as nullable that do not need
// to be so. The nullability of the output columns is the same as that of the primary input.
List<RelDataTypeField> primaryInputFields = getInput(0).getRowType().getFieldList();
final RelDataTypeFactory.Builder typeBuilder =
new RelDataTypeFactory.Builder(getCluster().getTypeFactory());
for (int i = 0; i < primaryInputFields.size(); i++) {
typeBuilder
.add(outputFields.get(i))
.nullable(primaryInputFields.get(i).getType().isNullable());
}

return typeBuilder.build();
}
}
4 changes: 4 additions & 0 deletions core/src/main/java/org/apache/calcite/rel/core/SetOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ public abstract SetOp copy(
}

@Override protected RelDataType deriveRowType() {
return deriveLeastRestrictiveRowType();
}

protected RelDataType deriveLeastRestrictiveRowType() {
final List<RelDataType> inputRowTypes =
Util.transform(inputs, RelNode::getRowType);
final RelDataType rowType =
Expand Down
176 changes: 176 additions & 0 deletions core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
import java.util.stream.Collectors;

import static org.apache.calcite.test.Matchers.hasHints;
import static org.apache.calcite.test.Matchers.hasRelDataType;
import static org.apache.calcite.test.Matchers.hasTree;

import static org.hamcrest.CoreMatchers.allOf;
Expand Down Expand Up @@ -2317,6 +2318,69 @@ private static RelNode groupIdRel(RelBuilder builder, boolean extra) {
assertThat(root, hasTree(expected));
}

@ParameterizedTest
@ValueSource(booleans = {false, true})
void testUnionTypeDerivation(boolean all) {
final RelBuilder builder = RelBuilder.create(config().build());

RelDataType input1RowType =
new RelDataTypeFactory.Builder(builder.getTypeFactory())
.add("a", SqlTypeName.BIGINT)
.nullable(false)
.add("b", SqlTypeName.BIGINT)
.nullable(false)
.add("c", SqlTypeName.BIGINT)
.nullable(true)
.add("d", SqlTypeName.BIGINT)
.nullable(true)
.build();

RelDataType input2RowType =
new RelDataTypeFactory.Builder(builder.getTypeFactory())
.add("a", SqlTypeName.BIGINT)
.nullable(false)
.add("b", SqlTypeName.BIGINT)
.nullable(false)
.add("c", SqlTypeName.BIGINT)
.nullable(false)
.add("d", SqlTypeName.BIGINT)
.nullable(false)
.build();

RelDataType input3RowType =
new RelDataTypeFactory.Builder(builder.getTypeFactory())
.add("a", SqlTypeName.BIGINT)
.nullable(false)
.add("b", SqlTypeName.BIGINT)
.nullable(true)
.add("c", SqlTypeName.BIGINT)
.nullable(false)
.add("d", SqlTypeName.BIGINT)
.nullable(true)
.build();

RelNode root =
builder
.values(input1RowType)
.values(input2RowType)
.values(input3RowType)
.union(all, 3)
.build();

RelDataType expectedRowType =
new RelDataTypeFactory.Builder(builder.getTypeFactory())
.add("a", SqlTypeName.BIGINT)
.nullable(false)
.add("b", SqlTypeName.BIGINT)
.nullable(true)
.add("c", SqlTypeName.BIGINT)
.nullable(true)
.add("d", SqlTypeName.BIGINT)
.nullable(true)
.build();
assertThat(root.getRowType(), hasRelDataType(expectedRowType));
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-1522">[CALCITE-1522]
* Fix error message for SetOp with incompatible args</a>. */
Expand Down Expand Up @@ -2521,6 +2585,69 @@ private static RelNode groupIdRel(RelBuilder builder, boolean extra) {
assertThat(root, hasTree(expected));
}

@ParameterizedTest
@ValueSource(booleans = {false, true})
void testIntersectTypeDerivation(boolean all) {
final RelBuilder builder = RelBuilder.create(config().build());

RelDataType input1RowType =
new RelDataTypeFactory.Builder(builder.getTypeFactory())
.add("a", SqlTypeName.BIGINT)
.nullable(false)
.add("b", SqlTypeName.BIGINT)
.nullable(false)
.add("c", SqlTypeName.BIGINT)
.nullable(true)
.add("d", SqlTypeName.BIGINT)
.nullable(true)
.build();

RelDataType input2RowType =
new RelDataTypeFactory.Builder(builder.getTypeFactory())
.add("a", SqlTypeName.BIGINT)
.nullable(true)
.add("b", SqlTypeName.BIGINT)
.nullable(true)
.add("c", SqlTypeName.BIGINT)
.nullable(true)
.add("d", SqlTypeName.BIGINT)
.nullable(true)
.build();

RelDataType input3RowType =
new RelDataTypeFactory.Builder(builder.getTypeFactory())
.add("a", SqlTypeName.BIGINT)
.nullable(false)
.add("b", SqlTypeName.BIGINT)
.nullable(true)
.add("c", SqlTypeName.BIGINT)
.nullable(false)
.add("d", SqlTypeName.BIGINT)
.nullable(true)
.build();

RelNode root =
builder
.values(input1RowType)
.values(input2RowType)
.values(input3RowType)
.intersect(all, 3)
.build();

RelDataType expectedRowType =
new RelDataTypeFactory.Builder(builder.getTypeFactory())
.add("a", SqlTypeName.BIGINT)
.nullable(false)
.add("b", SqlTypeName.BIGINT)
.nullable(false)
.add("c", SqlTypeName.BIGINT)
.nullable(false)
.add("d", SqlTypeName.BIGINT)
.nullable(true)
.build();
assertThat(root.getRowType(), hasRelDataType(expectedRowType));
}

@Test void testExcept() {
// Equivalent SQL:
// SELECT empno FROM emp
Expand Down Expand Up @@ -2548,6 +2675,55 @@ private static RelNode groupIdRel(RelBuilder builder, boolean extra) {
assertThat(root, hasTree(expected));
}

@ParameterizedTest
@ValueSource(booleans = {false, true})
void testExceptTypeDerivation(boolean all) {
final RelBuilder builder = RelBuilder.create(config().build());

RelDataType primaryRowType =
new RelDataTypeFactory.Builder(builder.getTypeFactory())
.add("a", SqlTypeName.BIGINT)
.nullable(false)
.add("b", SqlTypeName.BIGINT)
.nullable(false)
.add("c", SqlTypeName.BIGINT)
.nullable(true)
.add("d", SqlTypeName.BIGINT)
.nullable(true)
.build();

RelDataType secondaryRowType =
new RelDataTypeFactory.Builder(builder.getTypeFactory())
.add("a", SqlTypeName.BIGINT)
.nullable(false)
.add("b", SqlTypeName.BIGINT)
.nullable(true)
.add("c", SqlTypeName.BIGINT)
.nullable(false)
.add("d", SqlTypeName.BIGINT)
.nullable(true)
.build();

RelNode root =
builder.values(primaryRowType)
.values(secondaryRowType)
.minus(all)
.build();

RelDataType expectedRowType =
new RelDataTypeFactory.Builder(builder.getTypeFactory())
.add("a", SqlTypeName.BIGINT)
.nullable(false)
.add("b", SqlTypeName.BIGINT)
.nullable(false)
.add("c", SqlTypeName.BIGINT)
.nullable(true)
.add("d", SqlTypeName.BIGINT)
.nullable(true)
.build();
assertThat(root.getRowType(), hasRelDataType(expectedRowType));
}

/** Tests building a simple join. Also checks {@link RelBuilder#size()}
* at every step. */
@Test void testJoin() {
Expand Down
13 changes: 13 additions & 0 deletions testkit/src/main/java/org/apache/calcite/test/Matchers.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelValidityChecker;
import org.apache.calcite.rel.hint.Hintable;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.util.TestUtil;
import org.apache.calcite.util.Util;
Expand All @@ -34,6 +35,7 @@
import org.hamcrest.Matcher;
import org.hamcrest.TypeSafeMatcher;
import org.hamcrest.core.Is;
import org.hamcrest.core.IsEqual;
import org.hamcrest.core.StringContains;

import java.nio.charset.Charset;
Expand Down Expand Up @@ -245,6 +247,17 @@ public static Matcher<RelNode> hasFieldNames(String fieldNames) {
}
};
}

/**
* Creates a Matcher that matches a {@link RelDataType} if its
* {@link RelDataType#getFullTypeString()} is equal to that of the given {@code relDataType}.
*/
public static Matcher<RelDataType> hasRelDataType(RelDataType relDataType) {
return compose(IsEqual.equalTo(relDataType.getFullTypeString()), input -> {
return input.getFullTypeString();
});
}

/**
* Creates a Matcher that matches a {@link RelNode} if its string
* representation, after converting Windows-style line endings ("\r\n")
Expand Down

0 comments on commit 6238792

Please sign in to comment.