diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDF.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDF.java index 259fde8..00a4f38 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDF.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDF.java @@ -43,9 +43,11 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils.PrimitiveGrouping; +import org.apache.hadoop.io.BooleanWritable; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.LongWritable; import org.apache.hive.common.util.DateUtils; @@ -530,6 +532,20 @@ protected String getConstantStringValue(ObjectInspector[] arguments, int i) { return str; } + protected Boolean getConstantBooleanValue(ObjectInspector[] arguments, int i) + throws UDFArgumentTypeException { + Object constValue = ((ConstantObjectInspector) arguments[i]).getWritableConstantValue(); + if (constValue == null) { + return false; + } + if (constValue instanceof BooleanWritable) { + return ((BooleanWritable) constValue).get(); + } else { + throw new UDFArgumentTypeException(i, getFuncName() + " only takes BOOLEAN types as " + + getArgOrder(i) + " argument, got " + constValue.getClass()); + } + } + protected Integer getConstantIntValue(ObjectInspector[] arguments, int i) throws UDFArgumentTypeException { Object constValue = ((ConstantObjectInspector) arguments[i]).getWritableConstantValue(); diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFMonthsBetween.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFMonthsBetween.java index 35dc51a..be9127a 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFMonthsBetween.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFMonthsBetween.java @@ -35,6 +35,7 @@ import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; @@ -44,7 +45,8 @@ * UDFMonthsBetween. * */ -@Description(name = "months_between", value = "_FUNC_(date1, date2) - returns number of months between dates date1 and date2", +@Description(name = "months_between", value = "_FUNC_(date1, date2, roundOff) " + + "- returns number of months between dates date1 and date2", extended = "If date1 is later than date2, then the result is positive. " + "If date1 is earlier than date2, then the result is negative. " + "If date1 and date2 are either the same days of the month or both last days of months, " @@ -53,7 +55,7 @@ + "month and considers the difference in time components date1 and date2.\n" + "date1 and date2 type can be date, timestamp or string in the format " + "'yyyy-MM-dd' or 'yyyy-MM-dd HH:mm:ss'. " - + "The result is rounded to 8 decimal places.\n" + + "The result is rounded to 8 decimal places by default. Set roundOff=false otherwise. \n" + " Example:\n" + " > SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30');\n 3.94959677") public class GenericUDFMonthsBetween extends GenericUDF { @@ -64,14 +66,21 @@ private final Calendar cal1 = Calendar.getInstance(); private final Calendar cal2 = Calendar.getInstance(); private final DoubleWritable output = new DoubleWritable(); + private boolean isRoundOffNeeded = true; @Override public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { - checkArgsSize(arguments, 2, 2); + checkArgsSize(arguments, 2, 3); checkArgPrimitive(arguments, 0); checkArgPrimitive(arguments, 1); + if (arguments.length == 3) { + if (arguments[2] instanceof ConstantObjectInspector) { + isRoundOffNeeded = getConstantBooleanValue(arguments, 2); + } + } + // the function should support both short date and full timestamp format // time part of the timestamp should not be skipped checkArgGroups(arguments, 0, tsInputTypes, STRING_GROUP, DATE_GROUP); @@ -129,9 +138,11 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { // 1 sec is 0.000000373 months (1/2678400). 1 month is 31 days. // there should be no adjustments for leap seconds double monBtwDbl = monDiffInt + (sec1 - sec2) / 2678400D; - // Round a double to 8 decimal places. - double result = BigDecimal.valueOf(monBtwDbl).setScale(8, ROUND_HALF_UP).doubleValue(); - output.set(result); + if (isRoundOffNeeded) { + // Round a double to 8 decimal places. + monBtwDbl = BigDecimal.valueOf(monBtwDbl).setScale(8, ROUND_HALF_UP).doubleValue(); + } + output.set(monBtwDbl); return output; }