diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinPushTransitivePredicatesRule.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinPushTransitivePredicatesRule.java index 9cdc5e942c..1f186d58bc 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinPushTransitivePredicatesRule.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinPushTransitivePredicatesRule.java @@ -28,7 +28,6 @@ import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Join; import org.apache.calcite.rel.core.RelFactories.FilterFactory; -import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexBuilder; @@ -37,6 +36,7 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexUtil; import org.apache.calcite.rex.RexVisitorImpl; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.Util; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil; @@ -181,6 +181,25 @@ public Void visitCall(RexCall call) { } return super.visitCall(call); } + + @Override + public Void visitInputRef(RexInputRef inputRef) { + if (!areTypesCompatible(inputRef.getType(), types.get(inputRef.getIndex()).getType())) { + throw new Util.FoundOne(inputRef); + } + return super.visitInputRef(inputRef); + } + + private boolean areTypesCompatible(RelDataType type1, RelDataType type2) { + if (type1.equals(type2)) { + return true; + } + SqlTypeName sqlType1 = type1.getSqlTypeName(); + if (sqlType1 != null) { + return sqlType1.equals(type2.getSqlTypeName()); + } + return false; + } } } diff --git ql/src/test/queries/clientpositive/optimize_join_ptp.q ql/src/test/queries/clientpositive/optimize_join_ptp.q new file mode 100644 index 0000000000..5807ec350d --- /dev/null +++ ql/src/test/queries/clientpositive/optimize_join_ptp.q @@ -0,0 +1,16 @@ +set hive.mapred.mode=nonstrict; +set hive.explain.user=false; + +create table t1 (v string, k int); +insert into t1 values ('people', 10), ('strangers', 20), ('parents', 30); + +create table t2 (v string, k double); +insert into t2 values ('people', 10), ('strangers', 20), ('parents', 30); + +-- should not produce exceptions +explain +select * from t1 where t1.k in (select t2.k from t2 where t2.v='people') and t1.k<15; + +select * from t1 where t1.k in (select t2.k from t2 where t2.v='people') and t1.k<15; + + diff --git ql/src/test/results/clientpositive/optimize_join_ptp.q.out ql/src/test/results/clientpositive/optimize_join_ptp.q.out new file mode 100644 index 0000000000..0ac9eab880 --- /dev/null +++ ql/src/test/results/clientpositive/optimize_join_ptp.q.out @@ -0,0 +1,116 @@ +PREHOOK: query: create table t1 (v string, k int) +PREHOOK: type: CREATETABLE +PREHOOK: Output: database:default +PREHOOK: Output: default@t1 +POSTHOOK: query: create table t1 (v string, k int) +POSTHOOK: type: CREATETABLE +POSTHOOK: Output: database:default +POSTHOOK: Output: default@t1 +PREHOOK: query: insert into t1 values ('people', 10), ('strangers', 20), ('parents', 30) +PREHOOK: type: QUERY +PREHOOK: Output: default@t1 +POSTHOOK: query: insert into t1 values ('people', 10), ('strangers', 20), ('parents', 30) +POSTHOOK: type: QUERY +POSTHOOK: Output: default@t1 +POSTHOOK: Lineage: t1.k EXPRESSION [(values__tmp__table__1)values__tmp__table__1.FieldSchema(name:tmp_values_col2, type:string, comment:), ] +POSTHOOK: Lineage: t1.v SIMPLE [(values__tmp__table__1)values__tmp__table__1.FieldSchema(name:tmp_values_col1, type:string, comment:), ] +PREHOOK: query: create table t2 (v string, k double) +PREHOOK: type: CREATETABLE +PREHOOK: Output: database:default +PREHOOK: Output: default@t2 +POSTHOOK: query: create table t2 (v string, k double) +POSTHOOK: type: CREATETABLE +POSTHOOK: Output: database:default +POSTHOOK: Output: default@t2 +PREHOOK: query: insert into t2 values ('people', 10), ('strangers', 20), ('parents', 30) +PREHOOK: type: QUERY +PREHOOK: Output: default@t2 +POSTHOOK: query: insert into t2 values ('people', 10), ('strangers', 20), ('parents', 30) +POSTHOOK: type: QUERY +POSTHOOK: Output: default@t2 +POSTHOOK: Lineage: t2.k EXPRESSION [(values__tmp__table__2)values__tmp__table__2.FieldSchema(name:tmp_values_col2, type:string, comment:), ] +POSTHOOK: Lineage: t2.v SIMPLE [(values__tmp__table__2)values__tmp__table__2.FieldSchema(name:tmp_values_col1, type:string, comment:), ] +PREHOOK: query: explain +select * from t1 where t1.k in (select t2.k from t2 where t2.v='people') and t1.k<15 +PREHOOK: type: QUERY +POSTHOOK: query: explain +select * from t1 where t1.k in (select t2.k from t2 where t2.v='people') and t1.k<15 +POSTHOOK: type: QUERY +STAGE DEPENDENCIES: + Stage-1 is a root stage + Stage-0 depends on stages: Stage-1 + +STAGE PLANS: + Stage: Stage-1 + Map Reduce + Map Operator Tree: + TableScan + alias: t1 + Statistics: Num rows: 3 Data size: 31 Basic stats: COMPLETE Column stats: NONE + Filter Operator + predicate: (k < 15) (type: boolean) + Statistics: Num rows: 1 Data size: 10 Basic stats: COMPLETE Column stats: NONE + Select Operator + expressions: v (type: string), k (type: int) + outputColumnNames: _col0, _col1 + Statistics: Num rows: 1 Data size: 10 Basic stats: COMPLETE Column stats: NONE + Reduce Output Operator + key expressions: UDFToDouble(_col1) (type: double) + sort order: + + Map-reduce partition columns: UDFToDouble(_col1) (type: double) + Statistics: Num rows: 1 Data size: 10 Basic stats: COMPLETE Column stats: NONE + value expressions: _col0 (type: string), _col1 (type: int) + TableScan + alias: t2 + Statistics: Num rows: 3 Data size: 37 Basic stats: COMPLETE Column stats: NONE + Filter Operator + predicate: ((v = 'people') and k is not null) (type: boolean) + Statistics: Num rows: 1 Data size: 12 Basic stats: COMPLETE Column stats: NONE + Select Operator + expressions: k (type: double) + outputColumnNames: _col0 + Statistics: Num rows: 1 Data size: 12 Basic stats: COMPLETE Column stats: NONE + Group By Operator + keys: _col0 (type: double) + mode: hash + outputColumnNames: _col0 + Statistics: Num rows: 1 Data size: 12 Basic stats: COMPLETE Column stats: NONE + Reduce Output Operator + key expressions: _col0 (type: double) + sort order: + + Map-reduce partition columns: _col0 (type: double) + Statistics: Num rows: 1 Data size: 12 Basic stats: COMPLETE Column stats: NONE + Reduce Operator Tree: + Join Operator + condition map: + Left Semi Join 0 to 1 + keys: + 0 UDFToDouble(_col1) (type: double) + 1 _col0 (type: double) + outputColumnNames: _col0, _col1 + Statistics: Num rows: 1 Data size: 11 Basic stats: COMPLETE Column stats: NONE + File Output Operator + compressed: false + Statistics: Num rows: 1 Data size: 11 Basic stats: COMPLETE Column stats: NONE + table: + input format: org.apache.hadoop.mapred.SequenceFileInputFormat + output format: org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat + serde: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe + + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + ListSink + +PREHOOK: query: select * from t1 where t1.k in (select t2.k from t2 where t2.v='people') and t1.k<15 +PREHOOK: type: QUERY +PREHOOK: Input: default@t1 +PREHOOK: Input: default@t2 +#### A masked pattern was here #### +POSTHOOK: query: select * from t1 where t1.k in (select t2.k from t2 where t2.v='people') and t1.k<15 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@t1 +POSTHOOK: Input: default@t2 +#### A masked pattern was here #### +people 10