Skip to content

Commit

Permalink
[CALCITE-2195] AggregateJoinTransposeRule fails to aggregate over unique column (Zhong Yu)
Browse files Browse the repository at this point in the history

Close apache#637
  • Loading branch information
yuzhong authored and julianhyde committed Mar 2, 2018
1 parent 483c0a6 commit 73a09a6
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,11 @@ public void onMatch(RelOptRuleCall call) {
for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
map.put(c.e, belowOffset + c.i);
}
final Mappings.TargetMapping mapping =
s == 0
? Mappings.createIdentity(fieldCount)
: Mappings.createShiftMapping(fieldCount + offset, 0, offset,
fieldCount);
final ImmutableBitSet belowAggregateKey =
belowAggregateKeyNotShifted.shift(-offset);
final boolean unique;
Expand Down Expand Up @@ -224,17 +229,35 @@ public void onMatch(RelOptRuleCall call) {
if (unique) {
++uniqueCount;
side.aggregate = false;
side.newInput = joinInput;
relBuilder.push(joinInput);
final List<RexNode> projects = new ArrayList<>();
for (Integer i : belowAggregateKey) {
projects.add(relBuilder.field(i));
}
for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
final SqlAggFunction aggregation = aggCall.e.getAggregation();
final SqlSplittableAggFunction splitter =
Preconditions.checkNotNull(
aggregation.unwrap(SqlSplittableAggFunction.class));
if (!aggCall.e.getArgList().isEmpty()
&& fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
final RexNode singleton = splitter.singleton(rexBuilder,
joinInput.getRowType(), aggCall.e.transform(mapping));
if (singleton instanceof RexInputRef) {
side.split.put(aggCall.i, ((RexInputRef) singleton).getIndex());
} else {
projects.add(singleton);
side.split.put(aggCall.i, projects.size() - 1);
}
}
}
relBuilder.project(projects);
side.newInput = relBuilder.build();
} else {
side.aggregate = true;
List<AggregateCall> belowAggCalls = new ArrayList<>();
final SqlSplittableAggFunction.Registry<AggregateCall>
belowAggCallRegistry = registry(belowAggCalls);
final Mappings.TargetMapping mapping =
s == 0
? Mappings.createIdentity(fieldCount)
: Mappings.createShiftMapping(fieldCount + offset, 0, offset,
fieldCount);
final int oldGroupKeyCount = aggregate.getGroupCount();
final int newGroupKeyCount = belowAggregateKey.cardinality();
for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
Expand Down
17 changes: 17 additions & 0 deletions core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3010,6 +3010,23 @@ private void transitiveInference(RelOptRule... extraRules) throws Exception {
sql(sql).withPre(preProgram).with(program).check();
}

/**
 * Regression test for
 * <a href="https://issues.apache.org/jira/browse/CALCITE-2195">[CALCITE-2195]
 * AggregateJoinTransposeRule fails to aggregate over unique column</a>.
 *
 * <p>The query sums {@code B.sal}, where {@code B} is a
 * {@code select distinct sal} sub-query joined to {@code sales.emp} on
 * {@code sal}. {@link AggregateProjectMergeRule} is applied as a
 * pre-program, then {@link AggregateJoinTransposeRule#EXTENDED} as the
 * program under test. The resulting plans are verified by {@code check()}
 * (presumably against the reference entries recorded under this test's
 * name).
 */
@Test public void testPushAggregateThroughJoin6() {
  // The trailing newline is part of the query text; keep it as is.
  final String query = "select sum(B.sal)\n"
      + "from sales.emp as A\n"
      + "join (select distinct sal from sales.emp) as B\n"
      + "on A.sal=B.sal\n";

  // Preparation stage: runs before the rule under test.
  final HepProgram mergeStage =
      new HepProgramBuilder()
          .addRuleInstance(AggregateProjectMergeRule.INSTANCE)
          .build();

  // Stage under test: only the extended aggregate/join transpose rule.
  final HepProgram transposeStage =
      new HepProgramBuilder()
          .addRuleInstance(AggregateJoinTransposeRule.EXTENDED)
          .build();

  sql(query)
      .withPre(mergeStage)
      .with(transposeStage)
      .check();
}

/** SUM is the easiest aggregate function to split. */
@Test public void testPushAggregateSumThroughJoin() {
final HepProgram preProgram = new HepProgramBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5806,7 +5806,8 @@ LogicalProject(DEPTNO=[$0])
LogicalJoin(condition=[=($0, $1)], joinType=[inner])
LogicalAggregate(group=[{7}])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
LogicalProject(DEPTNO=[$0])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
]]>
</Resource>
</TestCase>
Expand All @@ -5828,11 +5829,40 @@ LogicalProject(DEPTNO=[$0], DEPTNO0=[$1])
<Resource name="planAfter">
<![CDATA[
LogicalProject(DEPTNO=[$0], DEPTNO0=[$1])
LogicalProject(DEPTNO=[$0], DEPTNO0=[$1])
LogicalJoin(condition=[=($0, $1)], joinType=[inner])
LogicalAggregate(group=[{7}])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalJoin(condition=[=($0, $1)], joinType=[inner])
LogicalAggregate(group=[{7}])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalProject(DEPTNO=[$0])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
]]>
</Resource>
</TestCase>
<TestCase name="testPushAggregateThroughJoin6">
<Resource name="sql">
<![CDATA[select sum(B.sal)
from sales.emp as A
join (select distinct sal from sales.emp) as B
on A.sal=B.sal
]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalAggregate(group=[{}], EXPR$0=[SUM($9)])
LogicalJoin(condition=[=($5, $9)], joinType=[inner])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalAggregate(group=[{5}])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
LogicalAggregate(group=[{}], EXPR$0=[SUM($3)])
LogicalProject(SAL=[$0], $f1=[$1], SAL0=[$2], $f3=[CAST(*($1, $2)):INTEGER])
LogicalJoin(condition=[=($0, $2)], joinType=[inner])
LogicalAggregate(group=[{5}], agg#0=[COUNT()])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalAggregate(group=[{5}])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
Expand Down

0 comments on commit 73a09a6

Please sign in to comment.