diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java
index c357329..41e0316 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java
@@ -32,18 +32,22 @@
 import org.apache.hadoop.hive.common.JavaUtils;
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.exec.AppMasterEventOperator;
+import org.apache.hadoop.hive.ql.exec.CommonJoinOperator;
 import org.apache.hadoop.hive.ql.exec.CommonMergeJoinOperator;
 import org.apache.hadoop.hive.ql.exec.DummyStoreOperator;
 import org.apache.hadoop.hive.ql.exec.FileSinkOperator;
 import org.apache.hadoop.hive.ql.exec.GroupByOperator;
 import org.apache.hadoop.hive.ql.exec.JoinOperator;
+import org.apache.hadoop.hive.ql.exec.LateralViewJoinOperator;
 import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
 import org.apache.hadoop.hive.ql.exec.MuxOperator;
 import org.apache.hadoop.hive.ql.exec.Operator;
 import org.apache.hadoop.hive.ql.exec.OperatorFactory;
 import org.apache.hadoop.hive.ql.exec.OperatorUtils;
+import org.apache.hadoop.hive.ql.exec.PTFOperator;
 import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
 import org.apache.hadoop.hive.ql.exec.TezDummyStoreOperator;
+import org.apache.hadoop.hive.ql.exec.UDTFOperator;
 import org.apache.hadoop.hive.ql.lib.Node;
 import org.apache.hadoop.hive.ql.lib.NodeProcessor;
 import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
@@ -61,6 +65,8 @@
 import org.apache.hadoop.hive.ql.plan.Statistics;
 import org.apache.hadoop.util.ReflectionUtils;
 
+import com.google.common.collect.Lists;
+
 /**
  * ConvertJoinMapJoin is an optimization that replaces a common join
  * (aka shuffle join) with a map join (aka broadcast or fragment replicate
@@ -538,16 +544,20 @@ public int getMapJoinConversionPos(JoinOperator joinOp, OptimizeTezProcContext c
         HiveConf.ConfVars.HIVECONVERTJOINNOCONDITIONALTASKTHRESHOLD);
 
     int bigTablePosition = -1;
-
+    // number of costly ops (Join, GB, PTF/Windowing, TF) below the big input
+    int bigInputNumberCostlyOps = -1;
+    // stats of the big input
     Statistics bigInputStat = null;
-    long totalSize = 0;
-    int pos = 0;
 
     // bigTableFound means we've encountered a table that's bigger than the
     // max. This table is either the the big table or we cannot convert.
     boolean bigTableFound = false;
 
-    for (Operator<? extends OperatorDesc> parentOp : joinOp.getParentOperators()) {
+    // total size of the inputs
+    long totalSize = 0;
+
+    for (int pos = 0; pos < joinOp.getParentOperators().size(); pos++) {
+      Operator<? extends OperatorDesc> parentOp = joinOp.getParentOperators().get(pos);
 
       Statistics currInputStat = parentOp.getStatistics();
       if (currInputStat == null) {
@@ -555,9 +565,10 @@ public int getMapJoinConversionPos(JoinOperator joinOp, OptimizeTezProcContext c
         return -1;
       }
 
+      int inputNumberCostlyOps = getNumberOfCostlyOps(parentOp);
       long inputSize = currInputStat.getDataSize();
-      if ((bigInputStat == null)
-          || ((bigInputStat != null) && (inputSize > bigInputStat.getDataSize()))) {
+      if (bigInputNumberCostlyOps == -1 || inputNumberCostlyOps > bigInputNumberCostlyOps
+          || (inputNumberCostlyOps == bigInputNumberCostlyOps && inputSize > bigInputStat.getDataSize())) {
 
         if (bigTableFound) {
           // cannot convert to map join; we've already chosen a big table
@@ -589,6 +600,7 @@ public int getMapJoinConversionPos(JoinOperator joinOp, OptimizeTezProcContext c
 
         if (bigTableCandidateSet.contains(pos)) {
           bigTablePosition = pos;
+          bigInputNumberCostlyOps = inputNumberCostlyOps;
           bigInputStat = currInputStat;
         }
       } else {
@@ -598,12 +610,39 @@ public int getMapJoinConversionPos(JoinOperator joinOp, OptimizeTezProcContext c
           return -1;
         }
       }
-      pos++;
     }
 
     return bigTablePosition;
   }
+  /* Count the number of costly ops below the input operator */
+  private static int getNumberOfCostlyOps(Operator<? extends OperatorDesc> op) {
+    if (op.getParentOperators() == null || op.getParentOperators().size() == 0) {
+      return 0;
+    }
+
+    int numberOps = 0;
+    List<Operator<? extends OperatorDesc>> ops = Lists.newArrayList(op.getParentOperators());
+    while (!ops.isEmpty()) {
+      List<Operator<? extends OperatorDesc>> newOps = Lists.newArrayList();
+      for (Operator<? extends OperatorDesc> parentOp : ops) {
+        if (parentOp instanceof CommonJoinOperator ||
+            parentOp instanceof GroupByOperator ||
+            parentOp instanceof LateralViewJoinOperator ||
+            parentOp instanceof PTFOperator ||
+            parentOp instanceof ReduceSinkOperator ||
+            parentOp instanceof UDTFOperator) {
+          numberOps++;
+        }
+        if (parentOp.getParentOperators() != null && parentOp.getParentOperators().size() != 0) {
+          newOps.addAll(parentOp.getParentOperators());
+        }
+      }
+      ops = newOps;
+    }
+    return numberOps;
+  }
+
   /*
    * Once we have decided on the map join, the tree would transform from
    *
    *
@@ -616,7 +655,6 @@ public int getMapJoinConversionPos(JoinOperator joinOp, OptimizeTezProcContext c
    *
    * for tez.
    */
-
   public MapJoinOperator convertJoinMapJoin(JoinOperator joinOp, OptimizeTezProcContext context,
       int bigTablePosition, boolean removeReduceSink) throws SemanticException {
     // bail on mux operator because currently the mux operator masks the emit keys
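
For reviewers who want to see the new selection order in isolation, below is a minimal, self-contained sketch of the heuristic this patch introduces: each join input is scored by a level-order walk that counts the costly operators below it (joins, group-bys, lateral view joins, PTF/windowing, shuffles, UDTFs), and the big (streamed) table is the input with the highest score; data size, which used to be the sole criterion, now only breaks ties. The Op class, its costly flag, and the sizes array are hypothetical stand-ins for Hive's Operator tree and Statistics, not the real classes.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Sketch of the big-table selection heuristic added by this patch.
 * Op, the costly flag, and the sizes array are illustrative stand-ins.
 */
public class BigTableHeuristicSketch {

  /** Minimal operator-DAG node: a "costly" flag plus parent links. */
  static class Op {
    final boolean costly; // stands in for the instanceof checks in the patch
    final List<Op> parents = new ArrayList<Op>();
    Op(boolean costly) { this.costly = costly; }
    Op parent(Op p) { parents.add(p); return this; }
  }

  /** Level-order walk over all ancestors, mirroring getNumberOfCostlyOps. */
  static int countCostlyOps(Op op) {
    int count = 0;
    List<Op> level = new ArrayList<Op>(op.parents);
    while (!level.isEmpty()) {
      List<Op> next = new ArrayList<Op>();
      for (Op p : level) {
        if (p.costly) {
          count++;
        }
        next.addAll(p.parents);
      }
      level = next;
    }
    return count;
  }

  /**
   * Pick the big-table position: most costly ops first, data size only as
   * the tie-breaker -- the comparison the patch adds to getMapJoinConversionPos.
   */
  static int chooseBigTable(List<Op> inputs, long[] sizes) {
    int bestPos = -1;
    int bestCostly = -1;
    long bestSize = -1;
    for (int pos = 0; pos < inputs.size(); pos++) {
      int costly = countCostlyOps(inputs.get(pos));
      if (bestCostly == -1 || costly > bestCostly
          || (costly == bestCostly && sizes[pos] > bestSize)) {
        bestPos = pos;
        bestCostly = costly;
        bestSize = sizes[pos];
      }
    }
    return bestPos;
  }

  public static void main(String[] args) {
    // Input 0: a plain table scan -- large, but no costly ancestors.
    Op scan = new Op(false);
    // Input 1: smaller, but fed by a join over a group-by (two costly ops).
    Op join = new Op(true).parent(new Op(true)).parent(new Op(false));
    Op smallSide = new Op(false).parent(join);

    // Under a pure size comparison input 0 would win; under the patched
    // comparison input 1 wins, so the output of the expensive pipeline is
    // streamed instead of being hashed and broadcast.
    int pos = chooseBigTable(Arrays.asList(scan, smallSide),
        new long[] { 1000000L, 10000L });
    System.out.println("big table position = " + pos); // prints 1
  }
}

The sketch deliberately omits the guards that surround this choice in the real method: the null-statistics bail-out, the bigTableCandidateSet membership test, and the noconditional task size threshold that decides whether the remaining inputs can be hashed at all. It also mirrors the patch's traversal exactly, including the absence of a visited set, so an ancestor reachable through two branches of the DAG is counted once per path.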