diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveGBOpConvUtil.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveGBOpConvUtil.java index 55f1247..1c43ea4 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveGBOpConvUtil.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/HiveGBOpConvUtil.java @@ -28,7 +28,6 @@ import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.util.ImmutableBitSet; import org.apache.hadoop.hive.conf.HiveConf; @@ -42,7 +41,6 @@ import org.apache.hadoop.hive.ql.exec.Utilities; import org.apache.hadoop.hive.ql.io.AcidUtils; import org.apache.hadoop.hive.ql.metadata.VirtualColumn; -import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveGroupingID; @@ -222,19 +220,23 @@ private static GBInfo getGBInfo(HiveAggregate aggRel, OpAttr inputOpAf, HiveConf } UDAFAttrs udafAttrs = new UDAFAttrs(); - udafAttrs.udafParams.addAll(HiveCalciteUtil.getExprNodes(aggCall.getArgList(), aggInputRel, - inputOpAf.tabAlias)); + List argExps = HiveCalciteUtil.getExprNodes(aggCall.getArgList(), aggInputRel, + inputOpAf.tabAlias); + udafAttrs.udafParams.addAll(argExps); udafAttrs.udafName = aggCall.getAggregation().getName(); udafAttrs.isDistinctUDAF = aggCall.isDistinct(); List argLst = new ArrayList(aggCall.getArgList()); List distColIndicesOfUDAF = new ArrayList(); List distUDAFParamsIndxInDistExprs = new ArrayList(); for (int i = 0; i < argLst.size(); i++) { - // NOTE: distinct expr can not be part of of GB key (we assume plan - // gen would have prevented it) + // NOTE: 
distinct expr can be part of GB key if (udafAttrs.isDistinctUDAF) { - distColIndicesOfUDAF.add(distParamInRefsToOutputPos.get(argLst.get(i))); - distUDAFParamsIndxInDistExprs.add(distParamInRefsToOutputPos.get(argLst.get(i))); + ExprNodeDesc argExpr = argExps.get(i); + Integer found = findIn(gbInfo.gbKeys, argExpr); + distColIndicesOfUDAF.add(null == found ? distParamInRefsToOutputPos.get(argLst.get(i)) + gbInfo.gbKeys.size() + + (gbInfo.grpSets.size() > 0 ? 1 : 0) : found); + distUDAFParamsIndxInDistExprs.add(null == found ? distParamInRefsToOutputPos.get(argLst.get(i)) + gbInfo.gbKeys.size() + + (gbInfo.grpSets.size() > 0 ? 1 : 0) : found); } else { // TODO: this seems wrong (following what Hive Regular does) if (!distParamInRefsToOutputPos.containsKey(argLst.get(i)) @@ -270,6 +272,15 @@ private static GBInfo getGBInfo(HiveAggregate aggRel, OpAttr inputOpAf, HiveConf return gbInfo; } + private static Integer findIn (List exprs, ExprNodeDesc expr) { + for (int i = 0; i < exprs.size(); i++) { + if (expr.isSame(exprs.get(i))) { + return i; + } + } + return null; + } + static OpAttr translateGB(OpAttr inputOpAf, HiveAggregate aggRel, HiveConf hc) throws SemanticException { OpAttr translatedGBOpAttr = null; @@ -648,7 +659,6 @@ private static OpAttr genMapSideGBRS(OpAttr inputOpAf, GBInfo gbInfo) throws Sem List outputValueColumnNames = new ArrayList(); ArrayList colInfoLst = new ArrayList(); GroupByOperator mapGB = (GroupByOperator) inputOpAf.inputs.get(0); - int distColStartIndx = gbInfo.gbKeys.size() + (gbInfo.grpSets.size() > 0 ?
1 : 0); ArrayList reduceKeys = getReduceKeysForRS(mapGB, 0, gbInfo.gbKeys.size() - 1, outputKeyColumnNames, false, colInfoLst, colExprMap, false, false); @@ -667,10 +677,9 @@ private static OpAttr genMapSideGBRS(OpAttr inputOpAf, GBInfo gbInfo) throws Sem ArrayList reduceValues = getValueKeysForRS(mapGB, mapGB.getConf().getKeys() .size(), outputValueColumnNames, colInfoLst, colExprMap, false, false); - List> distinctColIndices = getDistColIndices(gbInfo, distColStartIndx); ReduceSinkOperator rsOp = (ReduceSinkOperator) OperatorFactory.getAndMakeChild(PlanUtils - .getReduceSinkDesc(reduceKeys, keyLength, reduceValues, distinctColIndices, + .getReduceSinkDesc(reduceKeys, keyLength, reduceValues, gbInfo.distColIndices, outputKeyColumnNames, outputValueColumnNames, true, -1, getNumPartFieldsForMapSideRS(gbInfo), getParallelismForMapSideRS(gbInfo), AcidUtils.Operation.NOT_ACID), new RowSchema(colInfoLst), mapGB); @@ -685,7 +694,6 @@ private static OpAttr genMapSideRS(OpAttr inputOpAf, GBInfo gbInfo) throws Seman List outputKeyColumnNames = new ArrayList(); List outputValueColumnNames = new ArrayList(); ArrayList colInfoLst = new ArrayList(); - int distColStartIndx = gbInfo.gbKeys.size() + (gbInfo.grpSets.size() > 0 ? 1 : 0); String outputColName; // 1. Add GB Keys to reduce keys @@ -725,7 +733,7 @@ private static OpAttr genMapSideRS(OpAttr inputOpAf, GBInfo gbInfo) throws Seman // 4. 
Gen RS ReduceSinkOperator rsOp = (ReduceSinkOperator) OperatorFactory.getAndMakeChild(PlanUtils .getReduceSinkDesc(reduceKeys, keyLength, reduceValues, - getDistColIndices(gbInfo, distColStartIndx), outputKeyColumnNames, + gbInfo.distColIndices, outputKeyColumnNames, outputValueColumnNames, true, -1, getNumPartFieldsForMapSideRS(gbInfo), getParallelismForMapSideRS(gbInfo), AcidUtils.Operation.NOT_ACID), new RowSchema( colInfoLst), inputOpAf.inputs.get(0)); @@ -1215,21 +1223,6 @@ private static void addGrpSetCol(boolean createConstantExpr, String grpSetIDExpr return valueKeys; } - private static List> getDistColIndices(GBInfo gbAttrs, int distOffSet) - throws SemanticException { - List> distColIndices = new ArrayList>(); - - for (List udafDistCols : gbAttrs.distColIndices) { - List udfAdjustedDistColIndx = new ArrayList(); - for (Integer distIndx : udafDistCols) { - udfAdjustedDistColIndx.add(distIndx + distOffSet); - } - distColIndices.add(udfAdjustedDistColIndx); - } - - return distColIndices; - } - // TODO: Implement this private static ExprNodeDesc propConstDistUDAFParams() { return null;