diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java
index 9bf42ed384..d218face89 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java
@@ -36,6 +36,7 @@
 import org.apache.calcite.rel.core.AggregateCall;
 import org.apache.calcite.rel.core.CorrelationId;
 import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.core.TableFunctionScan;
 import org.apache.calcite.rel.core.TableScan;
 import org.apache.calcite.rel.metadata.RelMetadataQuery;
 import org.apache.calcite.rel.type.RelDataType;
@@ -68,6 +69,7 @@
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;
+import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableFunctionScan;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan;
 import org.apache.hadoop.hive.ql.parse.ColumnAccessInfo;
 import org.slf4j.Logger;
@@ -745,4 +747,46 @@ private void fetchColStats(RelNode key, TableScan tableAccessRel, ImmutableBitSe
   protected TrimResult result(RelNode r, final Mapping mapping) {
     return new TrimResult(r, mapping);
   }
+
+  /**
+   * Variant of {@link #trimFields(RelNode, ImmutableBitSet, Set)} for {@link HiveTableFunctionScan}.
+   * Copied from {@link org.apache.calcite.sql2rel.RelFieldTrimmer#trimFields(
+   * org.apache.calcite.rel.logical.LogicalTableFunctionScan, ImmutableBitSet, Set)},
+   * with {@code LogicalTableFunctionScan} replaced by {@link HiveTableFunctionScan}.
+   * The proper fix would be to implement this in Calcite.
+   */
+  public TrimResult trimFields(
+      HiveTableFunctionScan tabFun,
+      ImmutableBitSet fieldsUsed,
+      Set<RelDataTypeField> extraFields) {
+    final RelDataType rowType = tabFun.getRowType();
+    final int fieldCount = rowType.getFieldCount();
+    final List<RelNode> newInputs = new ArrayList<>();
+
+    for (RelNode input : tabFun.getInputs()) {
+      final int inputFieldCount = input.getRowType().getFieldCount();
+      ImmutableBitSet inputFieldsUsed = ImmutableBitSet.range(inputFieldCount);
+
+      // Create input with trimmed columns.
+      final Set<RelDataTypeField> inputExtraFields =
+          Collections.emptySet();
+      TrimResult trimResult =
+          trimChildRestore(
+              tabFun, input, inputFieldsUsed, inputExtraFields);
+      assert trimResult.right.isIdentity();
+      newInputs.add(trimResult.left);
+    }
+
+    TableFunctionScan newTabFun = tabFun;
+    if (!tabFun.getInputs().equals(newInputs)) {
+      newTabFun = tabFun.copy(tabFun.getTraitSet(), newInputs,
+          tabFun.getCall(), tabFun.getElementType(), tabFun.getRowType(),
+          tabFun.getColumnMappings());
+    }
+    assert newTabFun.getClass() == tabFun.getClass();
+
+    // Always project all fields.
+    Mapping mapping = Mappings.createIdentity(fieldCount);
+    return result(newTabFun, mapping);
+  }
 }
diff --git ql/src/test/results/clientpositive/except_all.q.out ql/src/test/results/clientpositive/except_all.q.out
index 020cba4287..5d1dc2211a 100644
--- ql/src/test/results/clientpositive/except_all.q.out
+++ ql/src/test/results/clientpositive/except_all.q.out
@@ -276,10 +276,10 @@ STAGE PLANS:
           Statistics: Num rows: 500 Data size: 97000 Basic stats: COMPLETE Column stats: COMPLETE
           Select Operator
             expressions: _col0 (type: string), _col1 (type: string), _col3 (type: bigint), (_col2 * _col3) (type: bigint)
-            outputColumnNames: _col0, _col1, _col3, _col4
+            outputColumnNames: _col0, _col1, _col2, _col3
             Statistics: Num rows: 500 Data size: 97000 Basic stats: COMPLETE Column stats: COMPLETE
             Group By Operator
-              aggregations: sum(_col4), sum(_col3)
+              aggregations: sum(_col3), sum(_col2)
               keys: _col0 (type: string), _col1 (type: string)
               minReductionHashAggr: 0.99
               mode: hash
@@ -297,10 +297,10 @@ STAGE PLANS:
           Statistics: Num rows: 500 Data size: 97000 Basic stats: COMPLETE Column stats: COMPLETE
           Select Operator
             expressions: _col0 (type: string), _col1 (type: string), _col3 (type: bigint), (_col2 * _col3) (type: bigint)
-            outputColumnNames: _col0, _col1, _col3, _col4
+            outputColumnNames: _col0, _col1, _col2, _col3
             Statistics: Num rows: 500 Data size: 97000 Basic stats: COMPLETE Column stats: COMPLETE
             Group By Operator
-              aggregations: sum(_col4), sum(_col3)
+              aggregations: sum(_col3), sum(_col2)
               keys: _col0 (type: string), _col1 (type: string)
               minReductionHashAggr: 0.99
               mode: hash
@@ -467,10 +467,10 @@ STAGE PLANS:
           Statistics: Num rows: 500 Data size: 97000 Basic stats: COMPLETE Column stats: COMPLETE
           Select Operator
             expressions: _col0 (type: string), _col1 (type: string), _col3 (type: bigint), (_col2 * _col3) (type: bigint)
-            outputColumnNames: _col0, _col1, _col3, _col4
+            outputColumnNames: _col0, _col1, _col2, _col3
             Statistics: Num rows: 500 Data size: 97000 Basic stats: COMPLETE Column stats: COMPLETE
             Group By Operator
-              aggregations: sum(_col4), sum(_col3)
+              aggregations: sum(_col3), sum(_col2)
               keys: _col0 (type: string), _col1 (type: string)
               minReductionHashAggr: 0.99
               mode: hash
@@ -488,10 +488,10 @@ STAGE PLANS:
           Statistics: Num rows: 500 Data size: 97000 Basic stats: COMPLETE Column stats: COMPLETE
           Select Operator
             expressions: _col0 (type: string), _col1 (type: string), _col3 (type: bigint), (_col2 * _col3) (type: bigint)
-            outputColumnNames: _col0, _col1, _col3, _col4
+            outputColumnNames: _col0, _col1, _col2, _col3
             Statistics: Num rows: 500 Data size: 97000 Basic stats: COMPLETE Column stats: COMPLETE
             Group By Operator
-              aggregations: sum(_col4), sum(_col3)
+              aggregations: sum(_col3), sum(_col2)
               keys: _col0 (type: string), _col1 (type: string)
               minReductionHashAggr: 0.99
               mode: hash
diff --git ql/src/test/results/clientpositive/intersect_all_rj.q.out ql/src/test/results/clientpositive/intersect_all_rj.q.out
index b8ff98ae79..427b841a1b 100644
--- ql/src/test/results/clientpositive/intersect_all_rj.q.out
+++ ql/src/test/results/clientpositive/intersect_all_rj.q.out
@@ -180,12 +180,12 @@ HiveProject($f0=[$1])
     HiveAggregate(group=[{0}], agg#0=[count()])
       HiveProject($f0=[$0])
        HiveAggregate(group=[{0}])
-         HiveProject($f0=[CASE(IS NOT NULL($7), $7, if($5, $8, $6))])
-           HiveJoin(condition=[>=($1, $13)], joinType=[inner], algorithm=[none], cost=[not available])
-             HiveProject(int_col_10=[$0], bigint_col_3=[$1], BLOCK__OFFSET__INSIDE__FILE=[$2], INPUT__FILE__NAME=[$3], CAST=[CAST($4):RecordType(BIGINT writeid, INTEGER bucketid, BIGINT rowid)])
+         HiveProject($f0=[CASE(IS NOT NULL($3), $3, if($1, $4, $2))])
+           HiveJoin(condition=[>=($0, $5)], joinType=[inner], algorithm=[none], cost=[not available])
+             HiveProject(bigint_col_3=[$1])
               HiveFilter(condition=[IS NOT NULL($1)])
                 HiveTableScan(table=[[default, table_7]], table:alias=[a3])
-             HiveProject(boolean_col_16=[$0], timestamp_col_5=[$1], timestamp_col_15=[$2], timestamp_col_30=[$3], int_col_18=[$4], BLOCK__OFFSET__INSIDE__FILE=[$5], INPUT__FILE__NAME=[$6], ROW__ID=[$7], CAST=[CAST($4):BIGINT])
+             HiveProject(boolean_col_16=[$0], timestamp_col_5=[$1], timestamp_col_15=[$2], timestamp_col_30=[$3], CAST=[CAST($4):BIGINT])
               HiveFilter(condition=[IS NOT NULL(CAST($4):BIGINT)])
                 HiveTableScan(table=[[default, table_10]], table:alias=[a4])
     HiveProject($f0=[$0], $f1=[$1])
diff --git ql/src/test/results/clientpositive/llap/intersect_all_rj.q.out ql/src/test/results/clientpositive/llap/intersect_all_rj.q.out
index cdfbc2239e..c47452fabd 100644
--- ql/src/test/results/clientpositive/llap/intersect_all_rj.q.out
+++ ql/src/test/results/clientpositive/llap/intersect_all_rj.q.out
@@ -180,12 +180,12 @@ HiveProject($f0=[$1])
     HiveAggregate(group=[{0}], agg#0=[count()])
       HiveProject($f0=[$0])
        HiveAggregate(group=[{0}])
-         HiveProject($f0=[CASE(IS NOT NULL($7), $7, if($5, $8, $6))])
-           HiveJoin(condition=[>=($1, $13)], joinType=[inner], algorithm=[none], cost=[not available])
-             HiveProject(int_col_10=[$0], bigint_col_3=[$1], BLOCK__OFFSET__INSIDE__FILE=[$2], INPUT__FILE__NAME=[$3], CAST=[CAST($4):RecordType(BIGINT writeid, INTEGER bucketid, BIGINT rowid)])
+         HiveProject($f0=[CASE(IS NOT NULL($3), $3, if($1, $4, $2))])
+           HiveJoin(condition=[>=($0, $5)], joinType=[inner], algorithm=[none], cost=[not available])
+             HiveProject(bigint_col_3=[$1])
               HiveFilter(condition=[IS NOT NULL($1)])
                 HiveTableScan(table=[[default, table_7]], table:alias=[a3])
-             HiveProject(boolean_col_16=[$0], timestamp_col_5=[$1], timestamp_col_15=[$2], timestamp_col_30=[$3], int_col_18=[$4], BLOCK__OFFSET__INSIDE__FILE=[$5], INPUT__FILE__NAME=[$6], ROW__ID=[$7], CAST=[CAST($4):BIGINT])
+             HiveProject(boolean_col_16=[$0], timestamp_col_5=[$1], timestamp_col_15=[$2], timestamp_col_30=[$3], CAST=[CAST($4):BIGINT])
               HiveFilter(condition=[IS NOT NULL(CAST($4):BIGINT)])
                 HiveTableScan(table=[[default, table_10]], table:alias=[a4])
     HiveProject($f0=[$0], $f1=[$1])
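
Note on why the new overload takes effect: RelFieldTrimmer locates its trimFields overloads reflectively through Calcite's ReflectUtil.createMethodDispatcher, picking the most specific public overload for the RelNode's runtime class, so declaring trimFields(HiveTableFunctionScan, ...) in the subclass is enough; no registration is required. Below is a minimal sketch of that dispatch pattern. Only the ReflectUtil, ReflectiveVisitor, RelNode, and ImmutableBitSet names are real Calcite API; the TrimSketch class, its String return type, and the dispatch method are illustrative stand-ins, not Hive or Calcite code.

    import java.util.Set;

    import org.apache.calcite.rel.RelNode;
    import org.apache.calcite.util.ImmutableBitSet;
    import org.apache.calcite.util.ReflectUtil;
    import org.apache.calcite.util.ReflectiveVisitor;

    /** Hypothetical sketch of RelFieldTrimmer-style reflective dispatch. */
    public class TrimSketch implements ReflectiveVisitor {

      // Dispatcher keyed on the runtime class of the first argument; mirrors
      // the trimFieldsDispatcher field in Calcite's RelFieldTrimmer.
      private final ReflectUtil.MethodDispatcher<String> dispatcher =
          ReflectUtil.createMethodDispatcher(String.class, this, "trimFields",
              RelNode.class, ImmutableBitSet.class, Set.class);

      // Generic fallback, analogous to trimFields(RelNode, ...).
      public String trimFields(RelNode rel, ImmutableBitSet fieldsUsed,
          Set<Object> extraFields) {
        return "generic: " + rel.getClass().getSimpleName();
      }

      // invoke() resolves the most specific public trimFields overload for
      // rel's runtime class, which is how a subclass-declared variant such as
      // trimFields(HiveTableFunctionScan, ...) gets picked up automatically.
      public String dispatch(RelNode rel, ImmutableBitSet fieldsUsed,
          Set<Object> extraFields) {
        return dispatcher.invoke(rel, fieldsUsed, extraFields);
      }
    }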