diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java index 2bfd12a3f5..c5d1bb1257 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java @@ -19,6 +19,8 @@ import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -46,6 +48,7 @@ import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexPermuteInputsShuttle; +import org.apache.calcite.rex.RexTableInputRef; import org.apache.calcite.rex.RexVisitor; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.validate.SqlValidator; @@ -320,43 +323,79 @@ private boolean isRexLiteral(final RexNode rexNode) { // if those are columns are not being used further up private ImmutableBitSet generateGroupSetIfCardinalitySame(final Aggregate aggregate, final ImmutableBitSet originalGroupSet, final ImmutableBitSet fieldsUsed) { - Pair> tabToOrgCol = HiveRelOptUtil.getColumnOriginSet(aggregate.getInput(), - originalGroupSet); - if(tabToOrgCol == null) { - return originalGroupSet; - } - RelOptHiveTable tbl = (RelOptHiveTable)tabToOrgCol.left; - List backtrackedGBList = tabToOrgCol.right; - ImmutableBitSet backtrackedGBSet = ImmutableBitSet.builder().addAll(backtrackedGBList).build(); - List allKeys = tbl.getNonNullableKeys(); - ImmutableBitSet currentKey = null; - for(ImmutableBitSet key:allKeys) { - if(backtrackedGBSet.contains(key)) { - // only if grouping sets consist of keys - currentKey = key; - break; + RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder(); + RelMetadataQuery mq = aggregate.getCluster().getMetadataQuery(); + + Iterator iterator = originalGroupSet.iterator(); + Map, Pair, List>> mapGBKeysLineage= new HashMap<>(); + + while(iterator.hasNext()) { + Integer key = iterator.next(); + RexNode inputRef = rexBuilder.makeInputRef(aggregate.getInput(), key.intValue()); + Set exprLineage = mq.getExpressionLineage(aggregate, inputRef); + if(exprLineage != null && exprLineage.size() == 1){ + RexNode expr = exprLineage.iterator().next(); + if(expr instanceof RexTableInputRef) { + //TODO: what if expression + RexTableInputRef tblRef = (RexTableInputRef)expr; + Pair baseTable = Pair.of(tblRef.getTableRef().getTable(), tblRef.getTableRef().getEntityNumber()); + if(mapGBKeysLineage.containsKey(baseTable)) { + List baseCol = mapGBKeysLineage.get(baseTable).left; + baseCol.add(tblRef.getIndex()); + List gbKey = mapGBKeysLineage.get(baseTable).right; + gbKey.add(key); + } else { + List baseCol = new ArrayList<>(); + baseCol.add(tblRef.getIndex()); + List gbKey = new ArrayList<>(); + gbKey.add(key); + mapGBKeysLineage.put(baseTable, Pair.of(baseCol, gbKey)); + } + } } } - if(currentKey == null || currentKey.isEmpty()) { - return originalGroupSet; - } // we want to delete all columns in original GB set except the key ImmutableBitSet.Builder builder = ImmutableBitSet.builder(); - // we have established that this gb set contains keys and it is safe to remove rest of the columns - for(int i=0; i, Pair, List>> entry:mapGBKeysLineage.entrySet()) { + RelOptHiveTable tbl = (RelOptHiveTable)entry.getKey().left; + List backtrackedGBList = entry.getValue().left; + List gbKeys = entry.getValue().right; + + ImmutableBitSet backtrackedGBSet = ImmutableBitSet.builder().addAll(backtrackedGBList).build(); + + List allKeys = tbl.getNonNullableKeys(); + ImmutableBitSet currentKey = null; + for(ImmutableBitSet key:allKeys) { + if(backtrackedGBSet.contains(key)) { + // only if grouping sets consist of keys + currentKey = key; + break; + } + } + if(currentKey == null || currentKey.isEmpty()) { + continue; + } + + + // we have established that this gb set contains keys and it is safe to remove rest of the columns + for(int i=0; i4) subq; + create table web_sales(ws_order_number int, ws_item_sk int, ws_price float, constraint pk1 primary key(ws_order_number, ws_item_sk) disable rely); insert into web_sales values(1, 1, 1.2); @@ -435,3 +446,14 @@ insert into web_sales values(1, 1, 1.2); explain cbo select count(distinct ws_order_number) from web_sales; select count(distinct ws_order_number) from web_sales; drop table web_sales; + +-- UNION +create table t1(i int primary key disable rely, j int); +insert into t1 values(1,100),(2,200); +create table t2(i int primary key disable rely, j int); +insert into t2 values(2,1000),(4,500); + +explain cbo select i from (select i, j from t1 union all select i,j from t2) subq group by i,j; +select i from (select i, j from t1 union all select i,j from t2) subq group by i,j; +drop table t1; +drop table t2; \ No newline at end of file diff --git a/ql/src/test/results/clientpositive/llap/constraints_optimization.q.out b/ql/src/test/results/clientpositive/llap/constraints_optimization.q.out index f7ed9f58a8..1e700ce9ee 100644 --- a/ql/src/test/results/clientpositive/llap/constraints_optimization.q.out +++ b/ql/src/test/results/clientpositive/llap/constraints_optimization.q.out @@ -2742,6 +2742,50 @@ HiveAggregate(group=[{1}]) HiveFilter(condition=[IS NOT NULL($3)]) HiveTableScan(table=[[default, store_sales]], table:alias=[store_sales]) +PREHOOK: query: explain cbo select c_customer_sk from + (select c_first_name, c_customer_sk ,d_date solddate,count(*) cnt + from store_sales + ,date_dim + ,customer + where ss_sold_date_sk = d_date_sk + and ss_item_sk = c_customer_sk + group by c_first_name,c_customer_sk,d_date + having count(*) >4) subq +PREHOOK: type: QUERY +PREHOOK: Input: default@customer +PREHOOK: Input: default@date_dim +PREHOOK: Input: default@store_sales +#### A masked pattern was here #### +POSTHOOK: query: explain cbo select c_customer_sk from + (select c_first_name, c_customer_sk ,d_date solddate,count(*) cnt + from store_sales + ,date_dim + ,customer + where ss_sold_date_sk = d_date_sk + and ss_item_sk = c_customer_sk + group by c_first_name,c_customer_sk,d_date + having count(*) >4) subq +POSTHOOK: type: QUERY +POSTHOOK: Input: default@customer +POSTHOOK: Input: default@date_dim +POSTHOOK: Input: default@store_sales +#### A masked pattern was here #### +CBO PLAN: +HiveProject(c_customer_sk=[$0]) + HiveFilter(condition=[>($2, 4)]) + HiveProject(c_customer_sk=[$1], d_date=[$0], $f2=[$2]) + HiveAggregate(group=[{3, 4}], agg#0=[count()]) + HiveJoin(condition=[=($1, $4)], joinType=[inner], algorithm=[none], cost=[not available]) + HiveJoin(condition=[=($0, $2)], joinType=[inner], algorithm=[none], cost=[not available]) + HiveProject(ss_sold_date_sk=[$0], ss_item_sk=[$2]) + HiveFilter(condition=[IS NOT NULL($0)]) + HiveTableScan(table=[[default, store_sales]], table:alias=[store_sales]) + HiveProject(d_date_sk=[$0], d_date=[$2]) + HiveFilter(condition=[IS NOT NULL($0)]) + HiveTableScan(table=[[default, date_dim]], table:alias=[date_dim]) + HiveProject(c_customer_sk=[$0], c_first_name=[$8]) + HiveTableScan(table=[[default, customer]], table:alias=[customer]) + PREHOOK: query: create table web_sales(ws_order_number int, ws_item_sk int, ws_price float, constraint pk1 primary key(ws_order_number, ws_item_sk) disable rely) PREHOOK: type: CREATETABLE @@ -2805,3 +2849,89 @@ POSTHOOK: query: drop table web_sales POSTHOOK: type: DROPTABLE POSTHOOK: Input: default@web_sales POSTHOOK: Output: default@web_sales +PREHOOK: query: create table t1(i int primary key disable rely, j int) +PREHOOK: type: CREATETABLE +PREHOOK: Output: database:default +PREHOOK: Output: default@t1 +POSTHOOK: query: create table t1(i int primary key disable rely, j int) +POSTHOOK: type: CREATETABLE +POSTHOOK: Output: database:default +POSTHOOK: Output: default@t1 +PREHOOK: query: insert into t1 values(1,100),(2,200) +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +PREHOOK: Output: default@t1 +POSTHOOK: query: insert into t1 values(1,100),(2,200) +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +POSTHOOK: Output: default@t1 +POSTHOOK: Lineage: t1.i SCRIPT [] +POSTHOOK: Lineage: t1.j SCRIPT [] +PREHOOK: query: create table t2(i int primary key disable rely, j int) +PREHOOK: type: CREATETABLE +PREHOOK: Output: database:default +PREHOOK: Output: default@t2 +POSTHOOK: query: create table t2(i int primary key disable rely, j int) +POSTHOOK: type: CREATETABLE +POSTHOOK: Output: database:default +POSTHOOK: Output: default@t2 +PREHOOK: query: insert into t2 values(2,1000),(4,500) +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +PREHOOK: Output: default@t2 +POSTHOOK: query: insert into t2 values(2,1000),(4,500) +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +POSTHOOK: Output: default@t2 +POSTHOOK: Lineage: t2.i SCRIPT [] +POSTHOOK: Lineage: t2.j SCRIPT [] +PREHOOK: query: explain cbo select i from (select i, j from t1 union all select i,j from t2) subq group by i,j +PREHOOK: type: QUERY +PREHOOK: Input: default@t1 +PREHOOK: Input: default@t2 +#### A masked pattern was here #### +POSTHOOK: query: explain cbo select i from (select i, j from t1 union all select i,j from t2) subq group by i,j +POSTHOOK: type: QUERY +POSTHOOK: Input: default@t1 +POSTHOOK: Input: default@t2 +#### A masked pattern was here #### +CBO PLAN: +HiveProject(i=[$0]) + HiveAggregate(group=[{0, 1}]) + HiveProject(i=[$0], j=[$1]) + HiveUnion(all=[true]) + HiveProject(i=[$0], j=[$1]) + HiveTableScan(table=[[default, t1]], table:alias=[t1]) + HiveProject(i=[$0], j=[$1]) + HiveTableScan(table=[[default, t2]], table:alias=[t2]) + +PREHOOK: query: select i from (select i, j from t1 union all select i,j from t2) subq group by i,j +PREHOOK: type: QUERY +PREHOOK: Input: default@t1 +PREHOOK: Input: default@t2 +#### A masked pattern was here #### +POSTHOOK: query: select i from (select i, j from t1 union all select i,j from t2) subq group by i,j +POSTHOOK: type: QUERY +POSTHOOK: Input: default@t1 +POSTHOOK: Input: default@t2 +#### A masked pattern was here #### +2 +2 +4 +1 +PREHOOK: query: drop table t1 +PREHOOK: type: DROPTABLE +PREHOOK: Input: default@t1 +PREHOOK: Output: default@t1 +POSTHOOK: query: drop table t1 +POSTHOOK: type: DROPTABLE +POSTHOOK: Input: default@t1 +POSTHOOK: Output: default@t1 +PREHOOK: query: drop table t2 +PREHOOK: type: DROPTABLE +PREHOOK: Input: default@t2 +PREHOOK: Output: default@t2 +POSTHOOK: query: drop table t2 +POSTHOOK: type: DROPTABLE +POSTHOOK: Input: default@t2 +POSTHOOK: Output: default@t2