diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/DecimalNumDistinctValueEstimator.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/DecimalNumDistinctValueEstimator.java new file mode 100644 index 0000000..a05906e --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/DecimalNumDistinctValueEstimator.java @@ -0,0 +1,42 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.udf.generic; + +import org.apache.hadoop.hive.common.type.HiveDecimal; + +public class DecimalNumDistinctValueEstimator extends NumDistinctValueEstimator { + + public DecimalNumDistinctValueEstimator(int numBitVectors) { + super(numBitVectors); + } + + public DecimalNumDistinctValueEstimator(String s, int numBitVectors) { + super(s, numBitVectors); + } + + public void addToEstimator(HiveDecimal decimal) { + int v = decimal.hashCode(); + super.addToEstimator(v); + } + + public void addToEstimatorPCSA(HiveDecimal decimal) { + int v = decimal.hashCode(); + super.addToEstimatorPCSA(v); + } +} diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFComputeStats.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFComputeStats.java index 7348478..3b063eb 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFComputeStats.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFComputeStats.java @@ -17,28 +17,26 @@ */ package org.apache.hadoop.hive.ql.udf.generic; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.parse.SemanticException; import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructField; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableDoubleObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableLongObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableStringObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.*; import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.io.BytesWritable; @@ -88,9 +86,11 @@ public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) return new GenericUDAFStringStatsEvaluator(); case BINARY: return new GenericUDAFBinaryStatsEvaluator(); + case DECIMAL: + return new GenericUDAFDecimalStatsEvaluator(); default: throw new UDFArgumentTypeException(0, - "Only integer/long/timestamp/float/double/string/binary/boolean type argument " + + "Only integer/long/timestamp/float/double/string/binary/boolean/decimal type argument " + "is accepted but " + parameters[0].getTypeName() + " is passed."); } @@ -1474,4 +1474,305 @@ public Object terminate(AggregationBuffer agg) throws HiveException { return result; } } + + public static class GenericUDAFDecimalStatsEvaluator extends GenericUDAFEvaluator { + + /* + * Object Inspector corresponding to the input parameter. + */ + private transient PrimitiveObjectInspector inputOI; + private transient PrimitiveObjectInspector numVectorsOI; + private final static int MAX_BIT_VECTORS = 1024; + + /* Partial aggregation result returned by TerminatePartial. Partial result is a struct + * containing a long field named "count". + */ + private transient Object[] partialResult; + + /* Object Inspectors corresponding to the struct returned by TerminatePartial and the long + * field within the struct - "count" + */ + private transient StructObjectInspector soi; + + private transient StructField minField; + private transient WritableHiveDecimalObjectInspector minFieldOI; + + private transient StructField maxField; + private transient WritableHiveDecimalObjectInspector maxFieldOI; + + private transient StructField countNullsField; + private transient WritableLongObjectInspector countNullsFieldOI; + + private transient StructField ndvField; + private transient WritableStringObjectInspector ndvFieldOI; + + private transient StructField numBitVectorsField; + private transient WritableIntObjectInspector numBitVectorsFieldOI; + + /* Output of final result of the aggregation + */ + private transient Object[] result; + + private boolean warned = false; + + @Override + public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException { + super.init(m, parameters); + + // initialize input + if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) { + inputOI = (PrimitiveObjectInspector) parameters[0]; + numVectorsOI = (PrimitiveObjectInspector) parameters[1]; + } else { + soi = (StructObjectInspector) parameters[0]; + + minField = soi.getStructFieldRef("Min"); + minFieldOI = (WritableHiveDecimalObjectInspector) minField.getFieldObjectInspector(); + + maxField = soi.getStructFieldRef("Max"); + maxFieldOI = (WritableHiveDecimalObjectInspector) maxField.getFieldObjectInspector(); + + countNullsField = soi.getStructFieldRef("CountNulls"); + countNullsFieldOI = (WritableLongObjectInspector) countNullsField.getFieldObjectInspector(); + + ndvField = soi.getStructFieldRef("BitVector"); + ndvFieldOI = (WritableStringObjectInspector) ndvField.getFieldObjectInspector(); + + numBitVectorsField = soi.getStructFieldRef("NumBitVectors"); + numBitVectorsFieldOI = (WritableIntObjectInspector) + numBitVectorsField.getFieldObjectInspector(); + } + + // initialize output + if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) { + List foi = new ArrayList(); + foi.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); + foi.add(PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector); + foi.add(PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector); + foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + foi.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); + foi.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + + List fname = new ArrayList(); + fname.add("ColumnType"); + fname.add("Min"); + fname.add("Max"); + fname.add("CountNulls"); + fname.add("BitVector"); + fname.add("NumBitVectors"); + + partialResult = new Object[6]; + partialResult[0] = new Text(); + partialResult[1] = new HiveDecimalWritable(HiveDecimal.create(0)); + partialResult[2] = new HiveDecimalWritable(HiveDecimal.create(0)); + partialResult[3] = new LongWritable(0); + partialResult[4] = new Text(); + partialResult[5] = new IntWritable(0); + + return ObjectInspectorFactory.getStandardStructObjectInspector(fname, + foi); + } else { + List foi = new ArrayList(); + foi.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); + foi.add(PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector); + foi.add(PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector); + foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + + List fname = new ArrayList(); + fname.add("ColumnType"); + fname.add("Min"); + fname.add("Max"); + fname.add("CountNulls"); + fname.add("NumDistinctValues"); + + result = new Object[5]; + result[0] = new Text(); + result[1] = new HiveDecimalWritable(HiveDecimal.create(0)); + result[2] = new HiveDecimalWritable(HiveDecimal.create(0)); + result[3] = new LongWritable(0); + result[4] = new LongWritable(0); + + return ObjectInspectorFactory.getStandardStructObjectInspector(fname, + foi); + } + } + + @AggregationType(estimable = true) + public static class DecimalStatsAgg extends AbstractAggregationBuffer { + public String columnType; + public HiveDecimal min; /* Minimum value seen so far */ + public HiveDecimal max; /* Maximum value seen so far */ + public long countNulls; /* Count of number of null values seen so far */ + public DecimalNumDistinctValueEstimator numDV; /* Distinct value estimator */ + public boolean firstItem; /* First item in the aggBuf? */ + public int numBitVectors; + @Override + public int estimate() { + JavaDataModel model = JavaDataModel.get(); + return model.primitive1() * 2 + model.primitive2() + model.lengthOfDecimal() * 2 + + model.lengthFor(columnType) + model.lengthFor(numDV); + } + }; + + @Override + public AggregationBuffer getNewAggregationBuffer() throws HiveException { + DecimalStatsAgg result = new DecimalStatsAgg(); + reset(result); + return result; + } + + public void initNDVEstimator(DecimalStatsAgg aggBuffer, int numBitVectors) { + aggBuffer.numDV = new DecimalNumDistinctValueEstimator(numBitVectors); + aggBuffer.numDV.reset(); + } + + @Override + public void reset(AggregationBuffer agg) throws HiveException { + DecimalStatsAgg myagg = (DecimalStatsAgg) agg; + myagg.columnType = new String("Decimal"); + myagg.min = HiveDecimal.create(0); + myagg.max = HiveDecimal.create(0); + myagg.countNulls = 0; + myagg.firstItem = true; + } + + @Override + public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException { + Object p = parameters[0]; + DecimalStatsAgg myagg = (DecimalStatsAgg) agg; + boolean emptyTable = false; + + if (parameters[1] == null) { + emptyTable = true; + } + + if (myagg.firstItem) { + int numVectors = 0; + if (!emptyTable) { + numVectors = PrimitiveObjectInspectorUtils.getInt(parameters[1], numVectorsOI); + } + + if (numVectors > MAX_BIT_VECTORS) { + throw new HiveException("The maximum allowed value for number of bit vectors " + + " is " + MAX_BIT_VECTORS + ", but was passed " + numVectors + " bit vectors"); + } + + initNDVEstimator(myagg, numVectors); + myagg.firstItem = false; + myagg.numBitVectors = numVectors; + } + + if (!emptyTable) { + + //Update null counter if a null value is seen + if (p == null) { + myagg.countNulls++; + } + else { + try { + + HiveDecimal v = PrimitiveObjectInspectorUtils.getHiveDecimal(p, inputOI); + + //Update min counter if new value is less than min seen so far + if (v.compareTo(myagg.min) < 0) { + myagg.min = v; + } + + //Update max counter if new value is greater than max seen so far + if (v.compareTo(myagg.max) > 0) { + myagg.max = v; + } + + // Add value to NumDistinctValue Estimator + myagg.numDV.addToEstimator(v); + + } catch (NumberFormatException e) { + if (!warned) { + warned = true; + LOG.warn(getClass().getSimpleName() + " " + + StringUtils.stringifyException(e)); + LOG.warn(getClass().getSimpleName() + + " ignoring similar exceptions."); + } + } + } + } + } + + @Override + public Object terminatePartial(AggregationBuffer agg) throws HiveException { + DecimalStatsAgg myagg = (DecimalStatsAgg) agg; + + // Serialize numDistinctValue Estimator + Text t = myagg.numDV.serialize(); + + // Serialize the rest of the values in the AggBuffer + ((Text) partialResult[0]).set(myagg.columnType); + ((HiveDecimalWritable) partialResult[1]).set(myagg.min); + ((HiveDecimalWritable) partialResult[2]).set(myagg.max); + ((LongWritable) partialResult[3]).set(myagg.countNulls); + ((Text) partialResult[4]).set(t); + ((IntWritable) partialResult[5]).set(myagg.numBitVectors); + + return partialResult; + } + + @Override + public void merge(AggregationBuffer agg, Object partial) throws HiveException { + if (partial != null) { + DecimalStatsAgg myagg = (DecimalStatsAgg) agg; + + if (myagg.firstItem) { + Object partialValue = soi.getStructFieldData(partial, numBitVectorsField); + int numVectors = numBitVectorsFieldOI.get(partialValue); + initNDVEstimator(myagg, numVectors); + myagg.firstItem = false; + myagg.numBitVectors = numVectors; + } + + // Update min if min is lesser than the smallest value seen so far + Object partialValue = soi.getStructFieldData(partial, minField); + if (myagg.min.compareTo(minFieldOI.getPrimitiveJavaObject(partialValue)) > 0) { + myagg.min = minFieldOI.getPrimitiveJavaObject(partialValue); + } + + // Update max if max is greater than the largest value seen so far + partialValue = soi.getStructFieldData(partial, maxField); + if (myagg.max.compareTo(maxFieldOI.getPrimitiveJavaObject(partialValue)) < 0) { + myagg.max = maxFieldOI.getPrimitiveJavaObject(partialValue); + } + + // Update the null counter + partialValue = soi.getStructFieldData(partial, countNullsField); + myagg.countNulls += countNullsFieldOI.get(partialValue); + + // Merge numDistinctValue Estimators + partialValue = soi.getStructFieldData(partial, ndvField); + String v = ndvFieldOI.getPrimitiveJavaObject(partialValue); + + NumDistinctValueEstimator o = new NumDistinctValueEstimator(v, myagg.numBitVectors); + myagg.numDV.mergeEstimators(o); + } + } + + @Override + public Object terminate(AggregationBuffer agg) throws HiveException { + DecimalStatsAgg myagg = (DecimalStatsAgg) agg; + long numDV = 0; + + if (myagg.numBitVectors != 0) { + numDV = myagg.numDV.estimateNumDistinctValues(); + } + + // Serialize the result struct + ((Text) result[0]).set(myagg.columnType); + ((HiveDecimalWritable) result[1]).set(myagg.min); + ((HiveDecimalWritable) result[2]).set(myagg.max); + ((LongWritable) result[3]).set(myagg.countNulls); + ((LongWritable) result[4]).set(numDV); + + return result; + } + } }