diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/translator/DerivedTableInjector.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/translator/DerivedTableInjector.java index a655174..188d685 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/translator/DerivedTableInjector.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/translator/DerivedTableInjector.java @@ -26,7 +26,10 @@ import org.apache.hadoop.hive.ql.optimizer.optiq.reloperators.HiveAggregateRel; import org.apache.hadoop.hive.ql.optimizer.optiq.reloperators.HiveProjectRel; import org.apache.hadoop.hive.ql.optimizer.optiq.reloperators.HiveSortRel; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; +import org.eigenbase.rel.AggregateCall; import org.eigenbase.rel.AggregateRelBase; +import org.eigenbase.rel.Aggregation; import org.eigenbase.rel.EmptyRel; import org.eigenbase.rel.FilterRelBase; import org.eigenbase.rel.JoinRelBase; @@ -42,7 +45,12 @@ import org.eigenbase.rel.rules.MultiJoinRel; import org.eigenbase.relopt.hep.HepRelVertex; import org.eigenbase.relopt.volcano.RelSubset; +import org.eigenbase.reltype.RelDataType; +import org.eigenbase.reltype.RelDataTypeFactory; import org.eigenbase.rex.RexNode; +import org.eigenbase.sql.SqlKind; + +import com.google.common.collect.ImmutableList; public class DerivedTableInjector { @@ -99,8 +107,14 @@ private static void convertOpTree(RelNode rel, RelNode parent) { introduceDerivedTable(((HiveSortRel) rel).getChild(), rel); } } else if (rel instanceof HiveAggregateRel) { + RelNode newParent = parent; if (!validGBParent(rel, parent)) { - introduceDerivedTable(rel, parent); + newParent = introduceDerivedTable(rel, parent); + } + // check if groupby is empty and there is no other cols in aggr + // this should only happen when newParent is constant. + if (!checkEmptyGrpAggr(rel)) { + replaceEmptyGroupAggr(rel, newParent); } } } @@ -165,7 +179,7 @@ private static RelNode introduceDerivedTable(final RelNode rel) { return select; } - private static void introduceDerivedTable(final RelNode rel, RelNode parent) { + private static RelNode introduceDerivedTable(final RelNode rel, RelNode parent) { int i = 0; int pos = -1; List childList = parent.getInputs(); @@ -185,8 +199,59 @@ private static void introduceDerivedTable(final RelNode rel, RelNode parent) { RelNode select = introduceDerivedTable(rel); parent.replaceInput(pos, select); + + return select; } + private static void replaceEmptyGroupAggr(final RelNode rel, RelNode parent) { + // If this function is called, the parent should only include constant + List exps = parent.getChildExps(); + for (RexNode rexNode : exps) { + if (rexNode.getKind() != SqlKind.LITERAL) { + throw new RuntimeException("We expect " + parent.toString() + + " to contain only constants. However, " + rexNode.toString() + " is " + + rexNode.getKind()); + } + } + int i = 0; + int pos = -1; + List childList = parent.getInputs(); + for (RelNode child : childList) { + if (child == rel) { + pos = i; + break; + } + i++; + } + if (pos == -1) { + throw new RuntimeException("Couldn't find child node in parent's inputs"); + } + HiveAggregateRel oldAggRel = (HiveAggregateRel) rel; + RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory(); + RelDataType longType = TypeConverter.convert(TypeInfoFactory.longTypeInfo, typeFactory); + RelDataType intType = TypeConverter.convert(TypeInfoFactory.intTypeInfo, typeFactory); + // Create the dummy aggregation. + Aggregation countFn = (Aggregation) SqlFunctionConverter.getOptiqAggFn("count", + ImmutableList.of(intType), longType); + // TODO: Using 0 might be wrong; might need to walk down to find the + // proper index of a dummy. + List argList = ImmutableList.of(0); + AggregateCall dummyCall = new AggregateCall(countFn, false, argList, longType, null); + AggregateRelBase newAggRel = oldAggRel.copy(oldAggRel.getTraitSet(), oldAggRel.getChild(), + oldAggRel.getGroupSet(), ImmutableList.of(dummyCall)); + RelNode select = introduceDerivedTable(newAggRel); + parent.replaceInput(pos, select); + } + + private static boolean checkEmptyGrpAggr(RelNode gbNode) { + // Verify if both groupset and aggrfunction are empty) + AggregateRelBase aggrnode = (AggregateRelBase) gbNode; + if (aggrnode.getGroupSet().isEmpty() && aggrnode.getAggCallList().isEmpty()) { + return false; + } + return true; + } + private static boolean validJoinParent(RelNode joinNode, RelNode parent) { boolean validParent = true; diff --git a/ql/src/test/queries/clientpositive/cbo_correctness.q b/ql/src/test/queries/clientpositive/cbo_correctness.q index f7f0722..a65d88a 100644 --- a/ql/src/test/queries/clientpositive/cbo_correctness.q +++ b/ql/src/test/queries/clientpositive/cbo_correctness.q @@ -456,7 +456,29 @@ from (select b.key, count(*) ) a ; --- 17. get stats with empty partition list +-- 20. Test get stats with empty partition list select t1.value from t1 join t2 on t1.key = t2.key where t1.dt = '10' and t1.c_boolean = true; +-- 21. Test groupby is empty and there is no other cols in aggr +select unionsrc.key FROM (select 'tst1' as key, count(1) as value from src) unionsrc; + +select unionsrc.key, unionsrc.value FROM (select 'tst1' as key, count(1) as value from src) unionsrc; + +select unionsrc.key FROM (select 'max' as key, max(c_int) as value from t3 s1 + UNION ALL + select 'min' as key, min(c_int) as value from t3 s2 + UNION ALL + select 'avg' as key, avg(c_int) as value from t3 s3) unionsrc order by unionsrc.key; + +select unionsrc.key, unionsrc.value FROM (select 'max' as key, max(c_int) as value from t3 s1 + UNION ALL + select 'min' as key, min(c_int) as value from t3 s2 + UNION ALL + select 'avg' as key, avg(c_int) as value from t3 s3) unionsrc order by unionsrc.key; + +select unionsrc.key, count(1) FROM (select 'max' as key, max(c_int) as value from t3 s1 + UNION ALL + select 'min' as key, min(c_int) as value from t3 s2 + UNION ALL + select 'avg' as key, avg(c_int) as value from t3 s3) unionsrc group by unionsrc.key order by unionsrc.key; diff --git a/ql/src/test/results/clientpositive/cbo_correctness.q.out b/ql/src/test/results/clientpositive/cbo_correctness.q.out index 3335d4d..f3f94fd 100644 --- a/ql/src/test/results/clientpositive/cbo_correctness.q.out +++ b/ql/src/test/results/clientpositive/cbo_correctness.q.out @@ -18946,17 +18946,94 @@ POSTHOOK: Input: default@src_cbo 96 1 97 2 98 2 -PREHOOK: query: -- 17. get stats with empty partition list +PREHOOK: query: -- 20. Test get stats with empty partition list select t1.value from t1 join t2 on t1.key = t2.key where t1.dt = '10' and t1.c_boolean = true PREHOOK: type: QUERY PREHOOK: Input: default@t1 PREHOOK: Input: default@t2 PREHOOK: Input: default@t2@dt=2014 #### A masked pattern was here #### -POSTHOOK: query: -- 17. get stats with empty partition list +POSTHOOK: query: -- 20. Test get stats with empty partition list select t1.value from t1 join t2 on t1.key = t2.key where t1.dt = '10' and t1.c_boolean = true POSTHOOK: type: QUERY POSTHOOK: Input: default@t1 POSTHOOK: Input: default@t2 POSTHOOK: Input: default@t2@dt=2014 #### A masked pattern was here #### +PREHOOK: query: -- 21. Test groupby is empty and there is no other cols in aggr +select unionsrc.key FROM (select 'tst1' as key, count(1) as value from src) unionsrc +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: -- 21. Test groupby is empty and there is no other cols in aggr +select unionsrc.key FROM (select 'tst1' as key, count(1) as value from src) unionsrc +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +tst1 +PREHOOK: query: select unionsrc.key, unionsrc.value FROM (select 'tst1' as key, count(1) as value from src) unionsrc +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: select unionsrc.key, unionsrc.value FROM (select 'tst1' as key, count(1) as value from src) unionsrc +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +tst1 500 +PREHOOK: query: select unionsrc.key FROM (select 'max' as key, max(c_int) as value from t3 s1 + UNION ALL + select 'min' as key, min(c_int) as value from t3 s2 + UNION ALL + select 'avg' as key, avg(c_int) as value from t3 s3) unionsrc order by unionsrc.key +PREHOOK: type: QUERY +PREHOOK: Input: default@t3 +#### A masked pattern was here #### +POSTHOOK: query: select unionsrc.key FROM (select 'max' as key, max(c_int) as value from t3 s1 + UNION ALL + select 'min' as key, min(c_int) as value from t3 s2 + UNION ALL + select 'avg' as key, avg(c_int) as value from t3 s3) unionsrc order by unionsrc.key +POSTHOOK: type: QUERY +POSTHOOK: Input: default@t3 +#### A masked pattern was here #### +avg +max +min +PREHOOK: query: select unionsrc.key, unionsrc.value FROM (select 'max' as key, max(c_int) as value from t3 s1 + UNION ALL + select 'min' as key, min(c_int) as value from t3 s2 + UNION ALL + select 'avg' as key, avg(c_int) as value from t3 s3) unionsrc order by unionsrc.key +PREHOOK: type: QUERY +PREHOOK: Input: default@t3 +#### A masked pattern was here #### +POSTHOOK: query: select unionsrc.key, unionsrc.value FROM (select 'max' as key, max(c_int) as value from t3 s1 + UNION ALL + select 'min' as key, min(c_int) as value from t3 s2 + UNION ALL + select 'avg' as key, avg(c_int) as value from t3 s3) unionsrc order by unionsrc.key +POSTHOOK: type: QUERY +POSTHOOK: Input: default@t3 +#### A masked pattern was here #### +avg 1.5 +max 3.0 +min 1.0 +PREHOOK: query: select unionsrc.key, count(1) FROM (select 'max' as key, max(c_int) as value from t3 s1 + UNION ALL + select 'min' as key, min(c_int) as value from t3 s2 + UNION ALL + select 'avg' as key, avg(c_int) as value from t3 s3) unionsrc group by unionsrc.key order by unionsrc.key +PREHOOK: type: QUERY +PREHOOK: Input: default@t3 +#### A masked pattern was here #### +POSTHOOK: query: select unionsrc.key, count(1) FROM (select 'max' as key, max(c_int) as value from t3 s1 + UNION ALL + select 'min' as key, min(c_int) as value from t3 s2 + UNION ALL + select 'avg' as key, avg(c_int) as value from t3 s3) unionsrc group by unionsrc.key order by unionsrc.key +POSTHOOK: type: QUERY +POSTHOOK: Input: default@t3 +#### A masked pattern was here #### +avg 1 +max 1 +min 1