diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/jdbc/JDBCProjectPushDownRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/jdbc/JDBCProjectPushDownRule.java index 920518aa9f..490c3e74ba 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/jdbc/JDBCProjectPushDownRule.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/jdbc/JDBCProjectPushDownRule.java @@ -17,15 +17,20 @@ */ package org.apache.hadoop.hive.ql.optimizer.calcite.rules.jdbc; -import java.util.Arrays; - import org.apache.calcite.adapter.jdbc.JdbcRules.JdbcProject; -import org.apache.calcite.adapter.jdbc.JdbcRules.JdbcProjectRule; import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRuleCall; -import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexFieldAccess; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexLocalRef; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexOver; +import org.apache.calcite.rex.RexVisitor; +import org.apache.calcite.rex.RexVisitorImpl; +import org.apache.calcite.sql.SqlDialect; +import org.apache.calcite.util.Util; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.jdbc.HiveJdbcConverter; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject; import org.slf4j.Logger; @@ -56,6 +61,9 @@ public boolean matches(RelOptRuleCall call) { if (!JDBCRexCallValidator.isValidJdbcOperation(currProject, conv.getJdbcDialect())) { return false; } + if (!validDataType(conv.getJdbcDialect(), currProject)) { + return false; + } } return true; @@ -78,4 +86,70 @@ public void onMatch(RelOptRuleCall call) { call.transformTo(converter.copy(converter.getTraitSet(), jdbcProject)); } -}; + /** + * Returns whether a given expression contains only valid data types for this dialect. + */ + private static boolean validDataType(SqlDialect dialect, RexNode e) { + try { + RexVisitor visitor = new JdbcDataTypeValidatorVisitor(dialect); + e.accept(visitor); + return true; + } catch (Util.FoundOne ex) { + Util.swallow(ex, null); + return false; + } + } + + private static final class JdbcDataTypeValidatorVisitor extends RexVisitorImpl { + private final SqlDialect dialect; + + private JdbcDataTypeValidatorVisitor(SqlDialect dialect) { + super(true); + this.dialect = dialect; + } + + @Override public Void visitInputRef(RexInputRef inputRef) { + if (!dialect.supportsDataType(inputRef.getType())) { + throw Util.FoundOne.NULL; + } + return super.visitInputRef(inputRef); + } + + @Override public Void visitLocalRef(RexLocalRef localRef) { + if (!dialect.supportsDataType(localRef.getType())) { + throw Util.FoundOne.NULL; + } + return super.visitLocalRef(localRef); + } + + @Override public Void visitLiteral(RexLiteral literal) { + if (!dialect.supportsDataType(literal.getType())) { + throw Util.FoundOne.NULL; + } + return super.visitLiteral(literal); + } + + @Override public Void visitCall(RexCall call) { + if (!dialect.supportsDataType(call.getType())) { + throw Util.FoundOne.NULL; + } + return super.visitCall(call); + } + + @Override public Void visitOver(RexOver over) { + if (!dialect.supportsDataType(over.getType())) { + throw Util.FoundOne.NULL; + } + return super.visitOver(over); + } + + @Override public Void visitFieldAccess(RexFieldAccess fieldAccess) { + if (!dialect.supportsDataType(fieldAccess.getType())) { + throw Util.FoundOne.NULL; + } + return super.visitFieldAccess(fieldAccess); + } + + } + +}