diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java
index 55d2a16f03..6ca1248543 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java
@@ -840,22 +840,11 @@ public VectorExpression getVectorExpression(ExprNodeDesc exprDesc, VectorExpress
       }
       break;
     case ALL:
-      // Check if this is UDF for _bucket_number
-      if (expr.getGenericUDF() instanceof GenericUDFBucketNumber) {
-        if (LOG.isDebugEnabled()) {
-          LOG.debug("UDF to handle _bucket_number : Create BucketNumExpression");
-        }
-        int outCol = ocm.allocateOutputColumn(exprDesc.getTypeInfo());
-        ve = new BucketNumExpression(outCol);
-        ve.setInputTypeInfos(exprDesc.getTypeInfo());
-        ve.setOutputTypeInfo(exprDesc.getTypeInfo());
-      } else {
-        if (LOG.isDebugEnabled()) {
-          LOG.debug("We will try to use the VectorUDFAdaptor for " + exprDesc.toString()
-              + " because hive.vectorized.adaptor.usage.mode=all");
-        }
-        ve = getCustomUDFExpression(expr, mode);
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("We will try to use the VectorUDFAdaptor for " + exprDesc.toString()
+            + " because hive.vectorized.adaptor.usage.mode=all");
       }
+      ve = getCustomUDFExpression(expr, mode);
       break;
     default:
       throw new RuntimeException("Unknown hive vector adaptor usage mode " +
@@ -2124,6 +2113,11 @@ private VectorExpression getGenericUdfVectorExpression(GenericUDF udf,
       ve = getCastToTimestamp((GenericUDFTimestamp)udf, childExpr, mode, returnType);
     } else if (udf instanceof GenericUDFDate || udf instanceof GenericUDFToDate) {
       ve = getIdentityForDateToDate(childExpr, returnType);
+    } else if (udf instanceof GenericUDFBucketNumber) {
+      int outCol = ocm.allocateOutputColumn(returnType);
+      ve = new BucketNumExpression(outCol);
+      ve.setInputTypeInfos(returnType);
+      ve.setOutputTypeInfo(returnType);
     }
     if (ve != null) {
       return ve;
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/BucketNumExpression.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/BucketNumExpression.java
index d8c696c302..c4ce4c595b 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/BucketNumExpression.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/BucketNumExpression.java
@@ -23,8 +23,6 @@
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 
-import java.nio.ByteBuffer;
-
 /**
  * An expression representing _bucket_number.
  */
@@ -32,6 +30,8 @@
   private static final long serialVersionUID = 1L;
   private int rowNum = -1;
   private int bucketNum = -1;
+  private boolean rowSet = false;
+  private boolean bucketNumSet = false;
 
   public BucketNumExpression(int outputColNum) {
     super(outputColNum);
@@ -43,19 +43,32 @@ public void initBuffer(VectorizedRowBatch batch) {
     cv.initBuffer();
   }
 
-  public void setRowNum(final int rowNum) {
+  public void setRowNum(final int rowNum) throws HiveException{
     this.rowNum = rowNum;
+    if (rowSet) {
+      throw new HiveException("Row number is already set");
+    }
+    rowSet = true;
   }
 
-  public void setBucketNum(final int bucketNum) {
+  public void setBucketNum(final int bucketNum) throws HiveException{
     this.bucketNum = bucketNum;
+    if (bucketNumSet) {
+      throw new HiveException("Bucket number is already set");
+    }
+    bucketNumSet = true;
   }
 
   @Override
   public void evaluate(VectorizedRowBatch batch) throws HiveException {
+    if (!rowSet || !bucketNumSet) {
+      throw new HiveException("row number or bucket number is not set before evaluation");
+    }
     BytesColumnVector cv = (BytesColumnVector) batch.cols[outputColumnNum];
     String bucketNumStr = String.valueOf(bucketNum);
     cv.setVal(rowNum, bucketNumStr.getBytes(), 0, bucketNumStr.length());
+    rowSet = false;
+    bucketNumSet = false;
   }
 
   @Override
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/reducesink/VectorReduceSinkObjectHashOperator.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/reducesink/VectorReduceSinkObjectHashOperator.java
index 1a8395a71b..767df2161b 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/reducesink/VectorReduceSinkObjectHashOperator.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/reducesink/VectorReduceSinkObjectHashOperator.java
@@ -18,7 +18,6 @@
 
 package org.apache.hadoop.hive.ql.exec.vector.reducesink;
 
-import java.lang.reflect.Method;
 import java.util.Random;
 import java.util.function.BiFunction;
 
@@ -88,8 +87,6 @@
   private transient BiFunction<Object[], ObjectInspector[], Integer> hashFunc;
 
   private transient BucketNumExpression bucketExpr = null;
 
-  private transient Method buckectEvaluatorMethod;
-
   /** Kryo ctor. */
   protected VectorReduceSinkObjectHashOperator() {
@@ -142,9 +139,6 @@ private void evaluateBucketExpr(VectorizedRowBatch batch, int rowNum, int bucket
     bucketExpr.evaluate(batch);
   }
 
-  private void evaluateBucketDummy(VectorizedRowBatch batch, int rowNum, int bucketNum) {
-  }
-
   @Override
   protected void initializeOp(Configuration hconf) throws HiveException {
     super.initializeOp(hconf);
@@ -191,21 +185,13 @@ protected void initializeOp(Configuration hconf) throws HiveException {
         ObjectInspectorUtils::getBucketHashCodeOld;
 
     // Set function to evaluate _bucket_number if needed.
-    try {
-      buckectEvaluatorMethod = this.getClass().getDeclaredMethod("evaluateBucketDummy",
-          VectorizedRowBatch.class, int.class, int.class);
-
-      if (reduceSinkKeyExpressions != null) {
-        for (VectorExpression ve : reduceSinkKeyExpressions) {
-          if (ve instanceof BucketNumExpression) {
-            bucketExpr = (BucketNumExpression) ve;
-            buckectEvaluatorMethod = this.getClass().getDeclaredMethod("evaluateBucketExpr",
-                VectorizedRowBatch.class, int.class, int.class);
-            break;
-          }
+    if (reduceSinkKeyExpressions != null) {
+      for (VectorExpression ve : reduceSinkKeyExpressions) {
+        if (ve instanceof BucketNumExpression) {
+          bucketExpr = (BucketNumExpression) ve;
+          break;
+        }
         }
       }
-    } catch (NoSuchMethodException e) {
-      throw new HiveException("Failed to find method to evaluate _bucket_number");
     }
   }
@@ -292,7 +278,9 @@ public void process(Object row, int tag) throws HiveException {
           final int bucketNum = ObjectInspectorUtils.getBucketNumber(
               hashFunc.apply(bucketFieldValues, bucketObjectInspectors), numBuckets);
           final int hashCode = nonPartitionRandom.nextInt() * 31 + bucketNum;
-          buckectEvaluatorMethod.invoke(this, batch, batchIndex, bucketNum);
+          if (bucketExpr != null) {
+            evaluateBucketExpr(batch, batchIndex, bucketNum);
+          }
           postProcess(batch, batchIndex, tag, hashCode);
         }
       } else { // isEmptyPartition = false
@@ -303,7 +291,9 @@
           final int bucketNum = ObjectInspectorUtils.getBucketNumber(
               hashFunc.apply(bucketFieldValues, bucketObjectInspectors), numBuckets);
          final int hashCode = hashFunc.apply(partitionFieldValues, partitionObjectInspectors) * 31 + bucketNum;
-          buckectEvaluatorMethod.invoke(this, batch, batchIndex, bucketNum);
+          if (bucketExpr != null) {
+            evaluateBucketExpr(batch, batchIndex, bucketNum);
+          }
           postProcess(batch, batchIndex, tag, hashCode);
         }
       }
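
Taken together, the patch moves _bucket_number handling from the adaptor path in VectorizationContext into getGenericUdfVectorExpression, replaces the reflective Method lookup in VectorReduceSinkObjectHashOperator with a plain null check on bucketExpr, and adds set/evaluate guards to BucketNumExpression. The following is a minimal sketch of the protocol those guards enforce, not part of the patch: it assumes the patched classes are on the classpath, and the class name, column index 0, and the sample row/bucket values are made up for illustration.

    import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
    import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
    import org.apache.hadoop.hive.ql.exec.vector.expressions.BucketNumExpression;
    import org.apache.hadoop.hive.ql.metadata.HiveException;

    public class BucketNumProtocolSketch {
      public static void main(String[] args) throws HiveException {
        // One-column batch; column 0 receives the _bucket_number string.
        VectorizedRowBatch batch = new VectorizedRowBatch(1);
        batch.cols[0] = new BytesColumnVector();

        BucketNumExpression bucketExpr = new BucketNumExpression(0);
        bucketExpr.initBuffer(batch);

        // Both values must be set before each evaluate() call.
        bucketExpr.setRowNum(5);     // a second call before evaluate() would throw
        bucketExpr.setBucketNum(3);  // likewise guarded by bucketNumSet
        bucketExpr.evaluate(batch);  // writes "3" into row 5, then clears both flags

        // Another evaluate() here would throw a HiveException: the cleared
        // flags stop stale row/bucket state from being reused silently.
      }
    }

This is the invariant the operator now relies on: with the reflection indirection gone, each row either goes through setRowNum/setBucketNum/evaluate as a unit (when bucketExpr != null) or skips the expression entirely.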