Index: ql/src/java/org/apache/hadoop/hive/ql/optimizer/ColumnPrunerProcFactory.java
===================================================================
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/ColumnPrunerProcFactory.java (revision 1152395)
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/ColumnPrunerProcFactory.java (working copy)
@@ -505,8 +505,11 @@
       }
       for (int i = 0; i < keys.size(); i++) {
         String outputCol = keys.get(i);
+        newOutputColNames.add(outputCol);
         String[] nm = parResover.reverseLookup(outputCol);
         ColumnInfo colInfo = oldRR.get(nm[0], nm[1]);
         if (colInfo != null) {
+          String internalName = colInfo.getInternalName();
+          newMap.put(internalName, oldMap.get(internalName));
           newRR.put(nm[0], nm[1], colInfo);
         }
Index: ql/src/java/org/apache/hadoop/hive/ql/ppd/OpProcFactory.java
===================================================================
--- ql/src/java/org/apache/hadoop/hive/ql/ppd/OpProcFactory.java (revision 1152395)
+++ ql/src/java/org/apache/hadoop/hive/ql/ppd/OpProcFactory.java (working copy)
@@ -19,17 +19,19 @@
 
 import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Map.Entry;
 import java.util.Set;
 import java.util.Stack;
-import java.util.Map.Entry;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.ColumnInfo;
 import org.apache.hadoop.hive.ql.exec.FilterOperator;
 import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
 import org.apache.hadoop.hive.ql.exec.JoinOperator;
@@ -43,11 +45,13 @@
 import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
 import org.apache.hadoop.hive.ql.metadata.HiveStorageHandler;
 import org.apache.hadoop.hive.ql.metadata.HiveStoragePredicateHandler;
-import org.apache.hadoop.hive.ql.metadata.HiveUtils;
 import org.apache.hadoop.hive.ql.metadata.Table;
+import org.apache.hadoop.hive.ql.parse.ASTNode;
+import org.apache.hadoop.hive.ql.parse.HiveParser;
 import org.apache.hadoop.hive.ql.parse.OpParseContext;
 import org.apache.hadoop.hive.ql.parse.RowResolver;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
 import org.apache.hadoop.hive.ql.plan.FilterDesc;
@@ -217,6 +221,64 @@
       Set<String> aliases = getQualifiedAliases((JoinOperator) nd, owi
           .getRowResolver(nd));
       boolean hasUnpushedPredicates = mergeWithChildrenPred(nd, owi, null, aliases, false);
+
+      // If we have a query like select * from invites join invites2 on
+      // invites.ds=invites2.ds where invites.ds='2011-01-01', then we want to
+      // recognize transitivity and push the filter invites2.ds='2011-01-01'
+      // down to invites2 as well.
+      // At this point, we only support deriving such filters from equijoin
+      // conditions whose two sides are simple column references, and from
+      // predicates that reference a single column.
+      if (owi.getPrunedPreds((Operator<? extends Serializable>) nd) != null) {
+        // We want to use the row resolvers of the children of the join op
+        Map<String, RowResolver> aliasToRR = new HashMap<String, RowResolver>();
+        for (Operator<? extends Serializable> o : ((JoinOperator) nd).getParentOperators()) {
+          for (String alias : owi.getRowResolver(o).getTableNames()) {
+            aliasToRR.put(alias, owi.getRowResolver(o));
+          }
+        }
+
+        // Populate lhs and rhs of the equijoin conditions
+        ArrayList<ASTNode> lhs = owi.getParseContext().getJoinContext()
+            .get((JoinOperator) nd).getExpressions().get(0);
+        ArrayList<ASTNode> rhs = owi.getParseContext().getJoinContext()
+            .get((JoinOperator) nd).getExpressions().get(1);
+        int numEqualities = lhs.size();
+
+        Map<String, List<ExprNodeDesc>> oldFilters =
+            owi.getPrunedPreds((Operator<? extends Serializable>) nd).getFinalCandidates();
+        Map<String, List<ExprNodeDesc>> newFilters =
+            new HashMap<String, List<ExprNodeDesc>>();
+
+        // Replace columns in key in filter with columns in values. Each side
+        // of an equality is mapped to the other, so a filter on either column
+        // can be mirrored onto its counterpart.
+        Map<ColumnInfo, ColumnInfo> columnMap = new HashMap<ColumnInfo, ColumnInfo>();
+        for (int i = 0; i < numEqualities; i++) {
+          ColumnInfo left = getColumnInfoFromAST(lhs.get(i), aliasToRR);
+          ColumnInfo right = getColumnInfoFromAST(rhs.get(i), aliasToRR);
+          // Skip conditions that are not simple column-to-column equalities
+          if (left != null && right != null) {
+            columnMap.put(left, right);
+            columnMap.put(right, left);
+          }
+        }
+
+        for (Entry<ColumnInfo, ColumnInfo> e : columnMap.entrySet()) {
+          ColumnInfo left = e.getKey();
+          ColumnInfo right = e.getValue();
+          if (oldFilters.get(left.getTabAlias()) != null) {
+            for (ExprNodeDesc expr : oldFilters.get(left.getTabAlias())) {
+              Set<String> colsreferenced = new HashSet<String>(expr.getCols());
+              if (colsreferenced.size() == 1
+                  && colsreferenced.contains(left.getInternalName())) {
+                ExprNodeDesc newexpr = expr.clone();
+                replaceColumnReference(newexpr, left.getInternalName(),
+                    right.getInternalName());
+                if (newFilters.get(right.getTabAlias()) == null) {
+                  newFilters.put(right.getTabAlias(), new ArrayList<ExprNodeDesc>());
+                }
+                newFilters.get(right.getTabAlias()).add(newexpr);
+              }
+            }
+          }
+        }
+
+        for (Entry<String, List<ExprNodeDesc>> aliasToFilters : newFilters.entrySet()) {
+          owi.getPrunedPreds((Operator<? extends Serializable>) nd)
+              .addPushDowns(aliasToFilters.getKey(), aliasToFilters.getValue());
+        }
+      }
+
       if (HiveConf.getBoolVar(owi.getParseContext().getConf(),
           HiveConf.ConfVars.HIVEPPDREMOVEDUPLICATEFILTERS)) {
         if (hasUnpushedPredicates) {
@@ -227,7 +289,36 @@
       }
       return null;
     }
+
+    private ColumnInfo getColumnInfoFromAST(ASTNode nd,
+        Map<String, RowResolver> aliastoRR) throws SemanticException {
+      // Resolve AST nodes of the form "alias.column"; anything else is not
+      // a simple column reference and yields null.
+      if (nd.getType() == HiveParser.DOT) {
+        if (nd.getChildCount() == 2) {
+          if (nd.getChild(0).getType() == HiveParser.TOK_TABLE_OR_COL
+              && nd.getChild(0).getChildCount() == 1
+              && nd.getChild(1).getType() == HiveParser.Identifier) {
+            String alias = nd.getChild(0).getChild(0).getText();
+            String column = nd.getChild(1).getText();
+            RowResolver rr = aliastoRR.get(alias);
+            if (rr == null) {
+              return null;
+            }
+            return rr.get(alias, column);
+          }
+        }
+      }
+      return null;
+    }
+
+    private void replaceColumnReference(ExprNodeDesc expr,
+        String oldcolumn, String newcolumn) {
+      // Recursively rewrite every reference to oldcolumn into newcolumn
+      if (expr instanceof ExprNodeColumnDesc) {
+        if (((ExprNodeColumnDesc) expr).getColumn().equals(oldcolumn)) {
+          ((ExprNodeColumnDesc) expr).setColumn(newcolumn);
+        }
+      }
+
+      if (expr.getChildren() != null) {
+        for (ExprNodeDesc childexpr : expr.getChildren()) {
+          replaceColumnReference(childexpr, oldcolumn, newcolumn);
+        }
+      }
+    }
+
     /**
      * Figures out the aliases for whom it is safe to push predicates based on
      * ANSI SQL semantics For inner join, all predicates for all aliases can be
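
Note on the JoinPPD hunk above: the rewrite it performs can be illustrated in isolation. The following is a minimal, self-contained sketch, not Hive code; Expr, ColumnRef, Const, Call, and TransitivityDemo are hypothetical stand-ins for ExprNodeDesc, ExprNodeColumnDesc, and ExprNodeGenericFuncDesc. It only demonstrates the clone-then-rewrite step that replaceColumnReference applies to a filter whose column appears in an equijoin condition.

import java.util.ArrayList;
import java.util.List;

public class TransitivityDemo {

  // Toy expression tree standing in for Hive's ExprNodeDesc hierarchy.
  abstract static class Expr {
    final List<Expr> children = new ArrayList<Expr>();
    abstract Expr copy();
  }

  // Stand-in for ExprNodeColumnDesc: a mutable column reference.
  static class ColumnRef extends Expr {
    String column;
    ColumnRef(String column) { this.column = column; }
    Expr copy() { return new ColumnRef(column); }
    public String toString() { return column; }
  }

  // Stand-in for a constant operand.
  static class Const extends Expr {
    final String value;
    Const(String value) { this.value = value; }
    Expr copy() { return new Const(value); }
    public String toString() { return "'" + value + "'"; }
  }

  // Stand-in for ExprNodeGenericFuncDesc: a binary operator node.
  static class Call extends Expr {
    final String op;
    Call(String op, Expr... args) {
      this.op = op;
      for (Expr a : args) { children.add(a); }
    }
    Expr copy() {
      Call c = new Call(op);
      for (Expr child : children) { c.children.add(child.copy()); }
      return c;
    }
    public String toString() {
      return children.get(0) + " " + op + " " + children.get(1);
    }
  }

  // Same shape as the patch's replaceColumnReference: rewrite every
  // reference to oldColumn into newColumn, recursing into children.
  static void replaceColumnReference(Expr expr, String oldColumn, String newColumn) {
    if (expr instanceof ColumnRef && ((ColumnRef) expr).column.equals(oldColumn)) {
      ((ColumnRef) expr).column = newColumn;
    }
    for (Expr child : expr.children) {
      replaceColumnReference(child, oldColumn, newColumn);
    }
  }

  public static void main(String[] args) {
    // Filter on the left join input: invites.ds = '2011-01-01'
    Expr filter = new Call("=", new ColumnRef("invites.ds"), new Const("2011-01-01"));

    // The equijoin condition invites.ds = invites2.ds maps invites.ds to
    // invites2.ds, so clone the filter and rewrite the column reference.
    Expr pushed = filter.copy();
    replaceColumnReference(pushed, "invites.ds", "invites2.ds");

    System.out.println("original: " + filter);  // invites.ds = '2011-01-01'
    System.out.println("derived:  " + pushed);  // invites2.ds = '2011-01-01'
  }
}

Cloning before rewriting matters here: the original predicate must keep filtering its own join input (invites), while the derived copy filters the other input (invites2); in the patch both are then handed to the predicate pushdown machinery via addPushDowns.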