From 73a09a691a67eec46349c94b21cf3fa483775c02 Mon Sep 17 00:00:00 2001 From: yuzhong Date: Tue, 27 Feb 2018 22:41:14 +0800 Subject: [PATCH] [CALCITE-2195] AggregateJoinTransposeRule fails to aggregate over unique column (Zhong Yu) Close apache/calcite#637 --- .../rel/rules/AggregateJoinTransposeRule.java | 35 +++++++++++++--- .../apache/calcite/test/RelOptRulesTest.java | 17 ++++++++ .../apache/calcite/test/RelOptRulesTest.xml | 40 ++++++++++++++++--- 3 files changed, 81 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java index d7c86aa7b7e..10687024b3d 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java @@ -197,6 +197,11 @@ public void onMatch(RelOptRuleCall call) { for (Ord c : Ord.zip(belowAggregateKeyNotShifted)) { map.put(c.e, belowOffset + c.i); } + final Mappings.TargetMapping mapping = + s == 0 + ? Mappings.createIdentity(fieldCount) + : Mappings.createShiftMapping(fieldCount + offset, 0, offset, + fieldCount); final ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset); final boolean unique; @@ -224,17 +229,35 @@ public void onMatch(RelOptRuleCall call) { if (unique) { ++uniqueCount; side.aggregate = false; - side.newInput = joinInput; + relBuilder.push(joinInput); + final List projects = new ArrayList<>(); + for (Integer i : belowAggregateKey) { + projects.add(relBuilder.field(i)); + } + for (Ord aggCall : Ord.zip(aggregate.getAggCallList())) { + final SqlAggFunction aggregation = aggCall.e.getAggregation(); + final SqlSplittableAggFunction splitter = + Preconditions.checkNotNull( + aggregation.unwrap(SqlSplittableAggFunction.class)); + if (!aggCall.e.getArgList().isEmpty() + && fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) { + final RexNode singleton = splitter.singleton(rexBuilder, + joinInput.getRowType(), aggCall.e.transform(mapping)); + if (singleton instanceof RexInputRef) { + side.split.put(aggCall.i, ((RexInputRef) singleton).getIndex()); + } else { + projects.add(singleton); + side.split.put(aggCall.i, projects.size() - 1); + } + } + } + relBuilder.project(projects); + side.newInput = relBuilder.build(); } else { side.aggregate = true; List belowAggCalls = new ArrayList<>(); final SqlSplittableAggFunction.Registry belowAggCallRegistry = registry(belowAggCalls); - final Mappings.TargetMapping mapping = - s == 0 - ? Mappings.createIdentity(fieldCount) - : Mappings.createShiftMapping(fieldCount + offset, 0, offset, - fieldCount); final int oldGroupKeyCount = aggregate.getGroupCount(); final int newGroupKeyCount = belowAggregateKey.cardinality(); for (Ord aggCall : Ord.zip(aggregate.getAggCallList())) { diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java index 2e2a5f64a92..cc995e63b27 100644 --- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java @@ -3010,6 +3010,23 @@ private void transitiveInference(RelOptRule... extraRules) throws Exception { sql(sql).withPre(preProgram).with(program).check(); } + /** Test case for + * [CALCITE-2195] + * AggregateJoinTransposeRule fails to aggregate over unique column. */ + @Test public void testPushAggregateThroughJoin6() { + final HepProgram preProgram = new HepProgramBuilder() + .addRuleInstance(AggregateProjectMergeRule.INSTANCE) + .build(); + final HepProgram program = new HepProgramBuilder() + .addRuleInstance(AggregateJoinTransposeRule.EXTENDED) + .build(); + final String sql = "select sum(B.sal)\n" + + "from sales.emp as A\n" + + "join (select distinct sal from sales.emp) as B\n" + + "on A.sal=B.sal\n"; + sql(sql).withPre(preProgram).with(program).check(); + } + /** SUM is the easiest aggregate function to split. */ @Test public void testPushAggregateSumThroughJoin() { final HepProgram preProgram = new HepProgramBuilder() diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index 70e54c8550e..30162856ffa 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -5806,7 +5806,8 @@ LogicalProject(DEPTNO=[$0]) LogicalJoin(condition=[=($0, $1)], joinType=[inner]) LogicalAggregate(group=[{7}]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) - LogicalTableScan(table=[[CATALOG, SALES, DEPT]]) + LogicalProject(DEPTNO=[$0]) + LogicalTableScan(table=[[CATALOG, SALES, DEPT]]) ]]> @@ -5828,11 +5829,40 @@ LogicalProject(DEPTNO=[$0], DEPTNO0=[$1]) + + + + + + + + + + +