diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
index fce11c8..0a1f06d 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
@@ -2398,12 +2398,11 @@ public static String join(String... elements) {
     return builder.toString();
   }
 
-  public static void setColumnNameList(JobConf jobConf, Operator op) {
-    setColumnNameList(jobConf, op, false);
+  public static void setColumnNameList(JobConf jobConf, RowSchema rowSchema) {
+    setColumnNameList(jobConf, rowSchema, false);
   }
 
-  public static void setColumnNameList(JobConf jobConf, Operator op, boolean excludeVCs) {
-    RowSchema rowSchema = op.getSchema();
+  public static void setColumnNameList(JobConf jobConf, RowSchema rowSchema, boolean excludeVCs) {
     if (rowSchema == null) {
       return;
     }
@@ -2421,12 +2420,20 @@ public static void setColumnNameList(JobConf jobConf, Operator op, boolean exclu
     jobConf.set(serdeConstants.LIST_COLUMNS, columnNamesString);
   }
 
-  public static void setColumnTypeList(JobConf jobConf, Operator op) {
-    setColumnTypeList(jobConf, op, false);
+  public static void setColumnNameList(JobConf jobConf, Operator op) {
+    setColumnNameList(jobConf, op, false);
   }
 
-  public static void setColumnTypeList(JobConf jobConf, Operator op, boolean excludeVCs) {
+  public static void setColumnNameList(JobConf jobConf, Operator op, boolean excludeVCs) {
     RowSchema rowSchema = op.getSchema();
+    setColumnNameList(jobConf, rowSchema, excludeVCs);
+  }
+
+  public static void setColumnTypeList(JobConf jobConf, RowSchema rowSchema) {
+    setColumnTypeList(jobConf, rowSchema, false);
+  }
+
+  public static void setColumnTypeList(JobConf jobConf, RowSchema rowSchema, boolean excludeVCs) {
     if (rowSchema == null) {
       return;
     }
@@ -2444,6 +2451,15 @@ public static void setColumnTypeList(JobConf jobConf, Operator op, boolean exclu
     jobConf.set(serdeConstants.LIST_COLUMN_TYPES, columnTypesString);
   }
 
+  public static void setColumnTypeList(JobConf jobConf, Operator op) {
+    setColumnTypeList(jobConf, op, false);
+  }
+
+  public static void setColumnTypeList(JobConf jobConf, Operator op, boolean excludeVCs) {
+    RowSchema rowSchema = op.getSchema();
+    setColumnTypeList(jobConf, rowSchema, excludeVCs);
+  }
+
   public static String suffix = ".hashtable";
 
   public static Path generatePath(Path basePath, String dumpFilePrefix,
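The Utilities change above is a pure refactoring: the Operator-based overloads now delegate to new RowSchema-based ones, so a caller can hand in a schema that is not tied to a single operator (for example, one shared by several table scans over the same path). A minimal sketch of calling the new entry points, assuming a JobConf and a TableScanOperator are already in scope; the wrapper class and method names below are invented for illustration and are not part of the patch:

import org.apache.hadoop.hive.ql.exec.RowSchema;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.mapred.JobConf;

public class ColumnListConfSketch {
  // Publish column names/types to the JobConf from a RowSchema that need not
  // belong to any single operator.
  static void publishColumns(JobConf jobConf, TableScanOperator ts) {
    RowSchema schema = ts.getSchema();
    Utilities.setColumnNameList(jobConf, schema);        // new RowSchema-based overload
    Utilities.setColumnTypeList(jobConf, schema, true);  // excludeVCs variant
    // The Operator-based overloads still exist and simply delegate to the RowSchema ones.
    Utilities.setColumnNameList(jobConf, ts);
  }
}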
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/io/HiveInputFormat.java b/ql/src/java/org/apache/hadoop/hive/ql/io/HiveInputFormat.java
index 1f262d0..51c80e4 100755
--- a/ql/src/java/org/apache/hadoop/hive/ql/io/HiveInputFormat.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/io/HiveInputFormat.java
@@ -23,6 +23,8 @@
 import java.io.IOException;
 import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
@@ -44,16 +46,20 @@
 import org.apache.hadoop.hive.ql.exec.spark.SparkDynamicPartitionPruner;
 import org.apache.hadoop.hive.ql.plan.TableDesc;
 import org.apache.hadoop.hive.ql.exec.Operator;
+import org.apache.hadoop.hive.ql.exec.RowSchema;
 import org.apache.hadoop.hive.ql.exec.TableScanOperator;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.ql.exec.Utilities;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
 import org.apache.hadoop.hive.ql.log.PerfLogger;
+import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
 import org.apache.hadoop.hive.ql.plan.MapWork;
 import org.apache.hadoop.hive.ql.plan.OperatorDesc;
 import org.apache.hadoop.hive.ql.plan.PartitionDesc;
 import org.apache.hadoop.hive.ql.plan.TableScanDesc;
 import org.apache.hadoop.hive.ql.session.SessionState;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPOr;
 import org.apache.hadoop.hive.serde2.ColumnProjectionUtils;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.io.WritableComparable;
@@ -532,34 +538,21 @@ protected static PartitionDesc getPartitionDescFromPath(
     return partDesc;
   }
 
+  /**
+   * Push filter for single TS operator
+   */
   public static void pushFilters(JobConf jobConf, TableScanOperator tableScan) {
-
-    // ensure filters are not set from previous pushFilters
-    jobConf.unset(TableScanDesc.FILTER_TEXT_CONF_STR);
-    jobConf.unset(TableScanDesc.FILTER_EXPR_CONF_STR);
-
-    Utilities.unsetSchemaEvolution(jobConf);
-
     TableScanDesc scanDesc = tableScan.getConf();
     if (scanDesc == null) {
       return;
     }
 
+    Utilities.unsetSchemaEvolution(jobConf);
     Utilities.addTableSchemaToConf(jobConf, tableScan);
 
-    // construct column name list and types for reference by filter push down
-    Utilities.setColumnNameList(jobConf, tableScan);
-    Utilities.setColumnTypeList(jobConf, tableScan);
-    // push down filters
-    ExprNodeGenericFuncDesc filterExpr = (ExprNodeGenericFuncDesc)scanDesc.getFilterExpr();
-    if (filterExpr == null) {
-      return;
-    }
-
     String serializedFilterObj = scanDesc.getSerializedFilterObject();
-    String serializedFilterExpr = scanDesc.getSerializedFilterExpr();
-    boolean hasObj = serializedFilterObj != null, hasExpr = serializedFilterExpr != null;
-    if (!hasObj) {
+    boolean hasObj = serializedFilterObj != null;
+    if (!hasObj) {
       Serializable filterObject = scanDesc.getFilterObject();
       if (filterObject != null) {
         serializedFilterObj = SerializationUtilities.serializeObject(filterObject);
@@ -568,17 +561,46 @@ public static void pushFilters(JobConf jobConf, TableScanOperator tableScan) {
     if (serializedFilterObj != null) {
       jobConf.set(TableScanDesc.FILTER_OBJECT_CONF_STR, serializedFilterObj);
     }
+
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Pushdown initiated with " + (serializedFilterObj == null ? "" :
+          (", serializedFilterObj = " + serializedFilterObj + " (" + (hasObj ? "desc" : "new")
+          + ")")));
+    }
+
+    pushFilters(jobConf, tableScan.getSchema(), scanDesc.getFilterExpr(), scanDesc.getSerializedFilterExpr());
+  }
+
+  /**
+   * Push filter for multiple combined TS operators
+   */
+  private static void pushFilters(JobConf jobConf,
+                                  RowSchema rowSchema,
+                                  ExprNodeGenericFuncDesc filterExpr,
+                                  String serializedFilterExpr) {
+    // ensure filters are not set from previous pushFilters
+    jobConf.unset(TableScanDesc.FILTER_TEXT_CONF_STR);
+    jobConf.unset(TableScanDesc.FILTER_EXPR_CONF_STR);
+
+    // construct column name list and types for reference by filter push down
+    Utilities.setColumnNameList(jobConf, rowSchema);
+    Utilities.setColumnTypeList(jobConf, rowSchema);
+    // push down filters
+    if (filterExpr == null) {
+      return;
+    }
+
+    boolean hasExpr = serializedFilterExpr != null;
+    String filterText = filterExpr.getExprString();
     if (!hasExpr) {
       serializedFilterExpr = SerializationUtilities.serializeExpression(filterExpr);
     }
-    String filterText = filterExpr.getExprString();
     if (LOG.isDebugEnabled()) {
       LOG.debug("Pushdown initiated with filterText = " + filterText + ", filterExpr = "
         + filterExpr + ", serializedFilterExpr = " + serializedFilterExpr + " ("
-        + (hasExpr ? "desc" : "new") + ")" + (serializedFilterObj == null ? "" :
-        (", serializedFilterObj = " + serializedFilterObj + " (" + (hasObj ? "desc" : "new")
-        + ")")));
+        + (hasExpr ? "desc" : "new") + ")");
     }
+
     jobConf.set(TableScanDesc.FILTER_TEXT_CONF_STR, filterText);
     jobConf.set(TableScanDesc.FILTER_EXPR_CONF_STR, serializedFilterExpr);
   }
"desc" : "new") + ")" + (serializedFilterObj == null ? "" : - (", serializedFilterObj = " + serializedFilterObj + " (" + (hasObj ? "desc" : "new") - + ")"))); + + (hasExpr ? "desc" : "new") + ")"); } + jobConf.set(TableScanDesc.FILTER_TEXT_CONF_STR, filterText); jobConf.set(TableScanDesc.FILTER_EXPR_CONF_STR, serializedFilterExpr); } @@ -589,8 +611,11 @@ protected void pushProjectionsAndFilters(JobConf jobConf, Class inputFormatClass splitPathWithNoSchema, false); } - protected void pushProjectionsAndFilters(JobConf jobConf, Class inputFormatClass, - String splitPath, String splitPathWithNoSchema, boolean nonNative) { + protected void pushProjectionsAndFilters(JobConf jobConf, + Class inputFormatClass, + String splitPath, + String splitPathWithNoSchema, + boolean nonNative) { if (this.mrwork == null) { init(job); } @@ -599,7 +624,7 @@ protected void pushProjectionsAndFilters(JobConf jobConf, Class inputFormatClass return; } - ArrayList aliases = new ArrayList(); + Set aliases = new HashSet(); Iterator>> iterator = this.mrwork .getPathToAliases().entrySet().iterator(); @@ -622,24 +647,61 @@ protected void pushProjectionsAndFilters(JobConf jobConf, Class inputFormatClass splitPath.startsWith(key) || splitPathWithNoSchema.startsWith(key); } if (match) { - ArrayList list = entry.getValue(); - for (String val : list) { - aliases.add(val); - } + aliases.addAll(entry.getValue()); } } - for (String alias : aliases) { - Operator op = this.mrwork.getAliasToWork().get( - alias); + // Collect the needed columns from all the aliases and create ORed filter + // expression for the table. + boolean allColumnsNeeded = false; + boolean noFilters = false; + Set neededColumnIDs = new HashSet(); + List filterExprs = new ArrayList(); + RowSchema rowSchema = null; + + for(String alias : aliases) { + Operator op = mrwork.getAliasToWork().get(alias); if (op instanceof TableScanOperator) { TableScanOperator ts = (TableScanOperator) op; - // push down projections. - ColumnProjectionUtils.appendReadColumns( - jobConf, ts.getNeededColumnIDs(), ts.getNeededColumns()); - // push down filters - pushFilters(jobConf, ts); + + if (ts.getNeededColumnIDs() == null) { + allColumnsNeeded = true; + } else { + neededColumnIDs.addAll(ts.getNeededColumnIDs()); + } + + rowSchema = ts.getSchema(); + ExprNodeGenericFuncDesc filterExpr = + ts.getConf() == null ? 
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/io/parquet/ProjectionPusher.java b/ql/src/java/org/apache/hadoop/hive/ql/io/parquet/ProjectionPusher.java
index 017676b..77b83b0 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/io/parquet/ProjectionPusher.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/io/parquet/ProjectionPusher.java
@@ -16,11 +16,14 @@
 import java.io.IOException;
 import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
+import java.util.Set;
 
 import org.apache.hadoop.hive.ql.exec.SerializationUtilities;
 import org.slf4j.Logger;
@@ -28,12 +31,16 @@
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.exec.Operator;
+import org.apache.hadoop.hive.ql.exec.RowSchema;
 import org.apache.hadoop.hive.ql.exec.TableScanOperator;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.ql.exec.Utilities;
+import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
 import org.apache.hadoop.hive.ql.plan.MapWork;
 import org.apache.hadoop.hive.ql.plan.PartitionDesc;
 import org.apache.hadoop.hive.ql.plan.TableScanDesc;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPOr;
 import org.apache.hadoop.hive.serde2.ColumnProjectionUtils;
 import org.apache.hadoop.mapred.JobConf;
@@ -68,7 +75,8 @@ private void updateMrWork(final JobConf job) {
 
   @Deprecated  // Uses deprecated methods on ColumnProjectionUtils
   private void pushProjectionsAndFilters(final JobConf jobConf,
-      final String splitPath, final String splitPathWithNoSchema) {
+      final String splitPath,
+      final String splitPathWithNoSchema) {
 
     if (mapWork == null) {
       return;
@@ -76,53 +84,78 @@ private void pushProjectionsAndFilters(final JobConf jobConf,
       return;
     }
 
-    final ArrayList<String> aliases = new ArrayList<String>();
-    final Iterator<Entry<String, ArrayList<String>>> iterator = mapWork.getPathToAliases().entrySet().iterator();
+    final Set<String> aliases = new HashSet<String>();
+    final Iterator<Entry<String, ArrayList<String>>> iterator =
+        mapWork.getPathToAliases().entrySet().iterator();
 
     while (iterator.hasNext()) {
       final Entry<String, ArrayList<String>> entry = iterator.next();
       final String key = new Path(entry.getKey()).toUri().getPath();
       if (splitPath.equals(key) || splitPathWithNoSchema.equals(key)) {
-        final ArrayList<String> list = entry.getValue();
-        for (final String val : list) {
-          aliases.add(val);
-        }
+        aliases.addAll(entry.getValue());
       }
     }
 
-    for (final String alias : aliases) {
-      final Operator op = mapWork.getAliasToWork().get(
-          alias);
+    // Collect the needed columns from all the aliases and create ORed filter
+    // expression for the table.
+    boolean allColumnsNeeded = false;
+    boolean noFilters = false;
+    Set<Integer> neededColumnIDs = new HashSet<Integer>();
+    List<ExprNodeGenericFuncDesc> filterExprs = new ArrayList<ExprNodeGenericFuncDesc>();
+    RowSchema rowSchema = null;
+
+    for(String alias : aliases) {
+      final Operator op =
+          mapWork.getAliasToWork().get(alias);
       if (op != null && op instanceof TableScanOperator) {
-        final TableScanOperator tableScan = (TableScanOperator) op;
-
-        // push down projections
-        final List<Integer> list = tableScan.getNeededColumnIDs();
+        final TableScanOperator ts = (TableScanOperator) op;
 
-        if (list != null) {
-          ColumnProjectionUtils.appendReadColumnIDs(jobConf, list);
+        if (ts.getNeededColumnIDs() == null) {
+          allColumnsNeeded = true;
         } else {
-          ColumnProjectionUtils.setFullyReadColumns(jobConf);
+          neededColumnIDs.addAll(ts.getNeededColumnIDs());
         }
 
-        pushFilters(jobConf, tableScan);
+        rowSchema = ts.getSchema();
+        ExprNodeGenericFuncDesc filterExpr =
+            ts.getConf() == null ? null : ts.getConf().getFilterExpr();
+        noFilters = filterExpr == null;  // No filter if any TS has no filter expression
+        filterExprs.add(filterExpr);
       }
     }
-  }
 
-  private void pushFilters(final JobConf jobConf, final TableScanOperator tableScan) {
+    ExprNodeGenericFuncDesc tableFilterExpr = null;
+    if (!noFilters) {
+      try {
+        for (ExprNodeGenericFuncDesc filterExpr : filterExprs) {
+          if (tableFilterExpr == null) {
+            tableFilterExpr = filterExpr;
+          } else {
+            tableFilterExpr = ExprNodeGenericFuncDesc.newInstance(new GenericUDFOPOr(),
+                Arrays.<ExprNodeDesc>asList(tableFilterExpr, filterExpr));
+          }
+        }
+      } catch(UDFArgumentException ex) {
+        LOG.debug("Turn off filtering due to " + ex);
+        tableFilterExpr = null;
+      }
+    }
 
-    final TableScanDesc scanDesc = tableScan.getConf();
-    if (scanDesc == null) {
-      LOG.debug("Not pushing filters because TableScanDesc is null");
-      return;
+    // push down projections
+    if (!allColumnsNeeded) {
+      ColumnProjectionUtils.appendReadColumnIDs(jobConf, new ArrayList<Integer>(neededColumnIDs));
+    } else {
+      ColumnProjectionUtils.setFullyReadColumns(jobConf);
     }
 
+    pushFilters(jobConf, rowSchema, tableFilterExpr);
+  }
+
+  private void pushFilters(final JobConf jobConf, RowSchema rowSchema, ExprNodeGenericFuncDesc filterExpr) {
     // construct column name list for reference by filter push down
-    Utilities.setColumnNameList(jobConf, tableScan);
+    Utilities.setColumnNameList(jobConf, rowSchema);
     // push down filters
-    final ExprNodeGenericFuncDesc filterExpr = scanDesc.getFilterExpr();
     if (filterExpr == null) {
       LOG.debug("Not pushing filters because FilterExpr is null");
       return;
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/io/parquet/read/DataWritableReadSupport.java b/ql/src/java/org/apache/hadoop/hive/ql/io/parquet/read/DataWritableReadSupport.java
index 53f3b72..5a77926 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/io/parquet/read/DataWritableReadSupport.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/io/parquet/read/DataWritableReadSupport.java
@@ -168,12 +168,19 @@ private static MessageType getSchemaByName(MessageType schema, List colN
    *
    * @param schema Message schema where to search for column names.
    * @param colNames List of column names.
-   * @param colIndexes List of column indexes.
+   * @param colIndexes List of column indexes. 'null' for all the columns.
    * @return A MessageType object of the column names found.
    */
   private static MessageType getSchemaByIndex(MessageType schema, List<String> colNames, List<Integer> colIndexes) {
     List<Type> schemaTypes = new ArrayList<Type>();
 
+    if (colIndexes == null) {
+      colIndexes = new ArrayList<Integer>();
+      for (int i = 0; i < colNames.size(); i++) {
+        colIndexes.add(i);
+      }
+    }
+
     for (Integer i : colIndexes) {
       if (i < colNames.size()) {
         if (i < schema.getFieldCount()) {
diff --git a/ql/src/test/queries/clientpositive/parquet_join2.q b/ql/src/test/queries/clientpositive/parquet_join2.q
new file mode 100644
index 0000000..9d107c7
--- /dev/null
+++ b/ql/src/test/queries/clientpositive/parquet_join2.q
@@ -0,0 +1,14 @@
+set hive.optimize.index.filter = true;
+set hive.auto.convert.join=false;
+
+CREATE TABLE tbl1(id INT) STORED AS PARQUET;
+INSERT INTO tbl1 VALUES(1), (2);
+
+CREATE TABLE tbl2(id INT, value STRING) STORED AS PARQUET;
+INSERT INTO tbl2 VALUES(1, 'value1');
+INSERT INTO tbl2 VALUES(1, 'value2');
+
+select tbl1.id, t1.value, t2.value
+FROM tbl1
+JOIN (SELECT * FROM tbl2 WHERE value='value1') t1 ON tbl1.id=t1.id
+JOIN (SELECT * FROM tbl2 WHERE value='value2') t2 ON tbl1.id=t2.id;
diff --git a/ql/src/test/results/clientpositive/parquet_join2.q.out b/ql/src/test/results/clientpositive/parquet_join2.q.out
new file mode 100644
index 0000000..f25dcd8
--- /dev/null
+++ b/ql/src/test/results/clientpositive/parquet_join2.q.out
@@ -0,0 +1,62 @@
+PREHOOK: query: CREATE TABLE tbl1(id INT) STORED AS PARQUET
+PREHOOK: type: CREATETABLE
+PREHOOK: Output: database:default
+PREHOOK: Output: default@tbl1
+POSTHOOK: query: CREATE TABLE tbl1(id INT) STORED AS PARQUET
+POSTHOOK: type: CREATETABLE
+POSTHOOK: Output: database:default
+POSTHOOK: Output: default@tbl1
+PREHOOK: query: INSERT INTO tbl1 VALUES(1), (2)
+PREHOOK: type: QUERY
+PREHOOK: Input: default@values__tmp__table__1
+PREHOOK: Output: default@tbl1
+POSTHOOK: query: INSERT INTO tbl1 VALUES(1), (2)
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@values__tmp__table__1
+POSTHOOK: Output: default@tbl1
+POSTHOOK: Lineage: tbl1.id EXPRESSION [(values__tmp__table__1)values__tmp__table__1.FieldSchema(name:tmp_values_col1, type:string, comment:), ]
+PREHOOK: query: CREATE TABLE tbl2(id INT, value STRING) STORED AS PARQUET
+PREHOOK: type: CREATETABLE
+PREHOOK: Output: database:default
+PREHOOK: Output: default@tbl2
+POSTHOOK: query: CREATE TABLE tbl2(id INT, value STRING) STORED AS PARQUET
+POSTHOOK: type: CREATETABLE
+POSTHOOK: Output: database:default
+POSTHOOK: Output: default@tbl2
+PREHOOK: query: INSERT INTO tbl2 VALUES(1, 'value1')
+PREHOOK: type: QUERY
+PREHOOK: Input: default@values__tmp__table__2
+PREHOOK: Output: default@tbl2
+POSTHOOK: query: INSERT INTO tbl2 VALUES(1, 'value1')
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@values__tmp__table__2
+POSTHOOK: Output: default@tbl2
+POSTHOOK: Lineage: tbl2.id EXPRESSION [(values__tmp__table__2)values__tmp__table__2.FieldSchema(name:tmp_values_col1, type:string, comment:), ]
+POSTHOOK: Lineage: tbl2.value SIMPLE [(values__tmp__table__2)values__tmp__table__2.FieldSchema(name:tmp_values_col2, type:string, comment:), ]
+PREHOOK: query: INSERT INTO tbl2 VALUES(1, 'value2')
+PREHOOK: type: QUERY
+PREHOOK: Input: default@values__tmp__table__3
+PREHOOK: Output: default@tbl2
+POSTHOOK: query: INSERT INTO tbl2 VALUES(1, 'value2')
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@values__tmp__table__3
+POSTHOOK: Output: default@tbl2
+POSTHOOK: Lineage: tbl2.id EXPRESSION [(values__tmp__table__3)values__tmp__table__3.FieldSchema(name:tmp_values_col1, type:string, comment:), ]
+POSTHOOK: Lineage: tbl2.value SIMPLE [(values__tmp__table__3)values__tmp__table__3.FieldSchema(name:tmp_values_col2, type:string, comment:), ]
+PREHOOK: query: select tbl1.id, t1.value, t2.value
+FROM tbl1
+JOIN (SELECT * FROM tbl2 WHERE value='value1') t1 ON tbl1.id=t1.id
+JOIN (SELECT * FROM tbl2 WHERE value='value2') t2 ON tbl1.id=t2.id
+PREHOOK: type: QUERY
+PREHOOK: Input: default@tbl1
+PREHOOK: Input: default@tbl2
+#### A masked pattern was here ####
+POSTHOOK: query: select tbl1.id, t1.value, t2.value
+FROM tbl1
+JOIN (SELECT * FROM tbl2 WHERE value='value1') t1 ON tbl1.id=t1.id
+JOIN (SELECT * FROM tbl2 WHERE value='value2') t2 ON tbl1.id=t2.id
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@tbl1
+POSTHOOK: Input: default@tbl2
+#### A masked pattern was here ####
+1	value1	value2
diff --git a/storage-api/src/java/org/apache/hadoop/hive/ql/io/sarg/ExpressionTree.java b/storage-api/src/java/org/apache/hadoop/hive/ql/io/sarg/ExpressionTree.java
index 577d95d..443083d 100644
--- a/storage-api/src/java/org/apache/hadoop/hive/ql/io/sarg/ExpressionTree.java
+++ b/storage-api/src/java/org/apache/hadoop/hive/ql/io/sarg/ExpressionTree.java
@@ -31,7 +31,7 @@
   public enum Operator {OR, AND, NOT, LEAF, CONSTANT}
   private final Operator operator;
   private final List<ExpressionTree> children;
-  private final int leaf;
+  private int leaf;
   private final SearchArgument.TruthValue constant;
 
   ExpressionTree() {
@@ -153,4 +153,8 @@ public Operator getOperator() {
   public int getLeaf() {
     return leaf;
   }
+
+  public void setLeaf(int leaf) {
+    this.leaf = leaf;
+  }
 }
diff --git a/storage-api/src/java/org/apache/hadoop/hive/ql/io/sarg/SearchArgumentImpl.java b/storage-api/src/java/org/apache/hadoop/hive/ql/io/sarg/SearchArgumentImpl.java
index eeff131..be5e67b 100644
--- a/storage-api/src/java/org/apache/hadoop/hive/ql/io/sarg/SearchArgumentImpl.java
+++ b/storage-api/src/java/org/apache/hadoop/hive/ql/io/sarg/SearchArgumentImpl.java
@@ -24,8 +24,12 @@
 import java.util.Arrays;
 import java.util.Deque;
 import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
+import java.util.Queue;
+import java.util.Set;
 
 /**
  * The implementation of SearchArguments.
@@ -429,15 +433,28 @@ static int compactLeaves(ExpressionTree tree, int next, int[] leafReorder) {
    * @return the fixed root
    */
   static ExpressionTree rewriteLeaves(ExpressionTree root,
-                                      int[] leafReorder) {
-    if (root.getOperator() == ExpressionTree.Operator.LEAF) {
-      return new ExpressionTree(leafReorder[root.getLeaf()]);
-    } else if (root.getChildren() != null){
-      List<ExpressionTree> children = root.getChildren();
-      for(int i=0; i < children.size(); ++i) {
-        children.set(i, rewriteLeaves(children.get(i), leafReorder));
+                                      int[] leafReorder) {
+    // The leaves could be shared in the tree. Use Set to remove the duplicates.
+    Set<ExpressionTree> leaves = new HashSet<ExpressionTree>();
+    Queue<ExpressionTree> nodes = new LinkedList<ExpressionTree>();
+    nodes.add(root);
+
+    while(!nodes.isEmpty()) {
+      ExpressionTree node = nodes.remove();
+      if (node.getOperator() == ExpressionTree.Operator.LEAF) {
+        leaves.add(node);
+      } else {
+        if (node.getChildren() != null){
+          nodes.addAll(node.getChildren());
+        }
       }
     }
+
+    // Update the leaf in place
+    for(ExpressionTree leaf : leaves) {
+      leaf.setLeaf(leafReorder[leaf.getLeaf()]);
+    }
+
     return root;
   }