Index: data/files/covar_tab.txt =================================================================== --- data/files/covar_tab.txt (revision 0) +++ data/files/covar_tab.txt (revision 0) @@ -0,0 +1,6 @@ +1 15 +2 3 +3 7 12 +4 4 14 +5 8 17 +6 2 11 Index: ql/src/test/results/clientpositive/udaf_covar_pop.q.out =================================================================== --- ql/src/test/results/clientpositive/udaf_covar_pop.q.out (revision 0) +++ ql/src/test/results/clientpositive/udaf_covar_pop.q.out (revision 0) @@ -0,0 +1,80 @@ +PREHOOK: query: DROP TABLE covar_tab +PREHOOK: type: DROPTABLE +POSTHOOK: query: DROP TABLE covar_tab +POSTHOOK: type: DROPTABLE +PREHOOK: query: CREATE TABLE covar_tab (a INT, b INT, c INT) +ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' +STORED AS TEXTFILE +PREHOOK: type: CREATETABLE +POSTHOOK: query: CREATE TABLE covar_tab (a INT, b INT, c INT) +ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' +STORED AS TEXTFILE +POSTHOOK: type: CREATETABLE +POSTHOOK: Output: default@covar_tab +PREHOOK: query: LOAD DATA LOCAL INPATH '../data/files/covar_tab.txt' OVERWRITE +INTO TABLE covar_tab +PREHOOK: type: LOAD +POSTHOOK: query: LOAD DATA LOCAL INPATH '../data/files/covar_tab.txt' OVERWRITE +INTO TABLE covar_tab +POSTHOOK: type: LOAD +POSTHOOK: Output: default@covar_tab +PREHOOK: query: DESCRIBE FUNCTION cover_pop +PREHOOK: type: DESCFUNCTION +POSTHOOK: query: DESCRIBE FUNCTION cover_pop +POSTHOOK: type: DESCFUNCTION +Function 'cover_pop' does not exist. 
+PREHOOK: query: SELECT covar_pop(b, c) FROM covar_tab WHERE a < 1 +PREHOOK: type: QUERY +PREHOOK: Input: default@covar_tab +PREHOOK: Output: file:/tmp/hadoop/hive_2010-08-12_13-36-36_438_6823139854193249749/10000 +POSTHOOK: query: SELECT covar_pop(b, c) FROM covar_tab WHERE a < 1 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@covar_tab +POSTHOOK: Output: file:/tmp/hadoop/hive_2010-08-12_13-36-36_438_6823139854193249749/10000 +NULL +PREHOOK: query: SELECT covar_pop(b, c) FROM covar_tab WHERE a < 3 +PREHOOK: type: QUERY +PREHOOK: Input: default@covar_tab +PREHOOK: Output: file:/tmp/hadoop/hive_2010-08-12_13-36-40_382_594310596383342223/10000 +POSTHOOK: query: SELECT covar_pop(b, c) FROM covar_tab WHERE a < 3 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@covar_tab +POSTHOOK: Output: file:/tmp/hadoop/hive_2010-08-12_13-36-40_382_594310596383342223/10000 +NULL +PREHOOK: query: SELECT covar_pop(b, c) FROM covar_tab WHERE a = 3 +PREHOOK: type: QUERY +PREHOOK: Input: default@covar_tab +PREHOOK: Output: file:/tmp/hadoop/hive_2010-08-12_13-36-43_670_5972663068645833017/10000 +POSTHOOK: query: SELECT covar_pop(b, c) FROM covar_tab WHERE a = 3 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@covar_tab +POSTHOOK: Output: file:/tmp/hadoop/hive_2010-08-12_13-36-43_670_5972663068645833017/10000 +0.0 +PREHOOK: query: SELECT a, covar_pop(b, c) FROM covar_tab GROUP BY a ORDER BY a +PREHOOK: type: QUERY +PREHOOK: Input: default@covar_tab +PREHOOK: Output: file:/tmp/hadoop/hive_2010-08-12_13-36-46_980_5327659877612256112/10000 +POSTHOOK: query: SELECT a, covar_pop(b, c) FROM covar_tab GROUP BY a ORDER BY a +POSTHOOK: type: QUERY +POSTHOOK: Input: default@covar_tab +POSTHOOK: Output: file:/tmp/hadoop/hive_2010-08-12_13-36-46_980_5327659877612256112/10000 +1 NULL +2 NULL +3 0.0 +4 0.0 +5 0.0 +6 0.0 +PREHOOK: query: SELECT covar_pop(b, c) FROM covar_tab +PREHOOK: type: QUERY +PREHOOK: Input: default@covar_tab +PREHOOK: Output: 
file:/tmp/hadoop/hive_2010-08-12_13-36-52_614_1667805970854382113/10000 +POSTHOOK: query: SELECT covar_pop(b, c) FROM covar_tab +POSTHOOK: type: QUERY +POSTHOOK: Input: default@covar_tab +POSTHOOK: Output: file:/tmp/hadoop/hive_2010-08-12_13-36-52_614_1667805970854382113/10000 +3.624999999999999 +PREHOOK: query: DROP TABLE covar_tab +PREHOOK: type: DROPTABLE +POSTHOOK: query: DROP TABLE covar_tab +POSTHOOK: type: DROPTABLE +POSTHOOK: Output: default@covar_tab Index: ql/src/test/results/clientpositive/show_functions.q.out =================================================================== --- ql/src/test/results/clientpositive/show_functions.q.out (revision 984986) +++ ql/src/test/results/clientpositive/show_functions.q.out (working copy) @@ -40,6 +40,8 @@ conv cos count +covar_pop +covar_samp date_add date_sub datediff @@ -165,6 +167,8 @@ conv cos count +covar_pop +covar_samp PREHOOK: query: SHOW FUNCTIONS '.*e$' PREHOOK: type: SHOWFUNCTIONS POSTHOOK: query: SHOW FUNCTIONS '.*e$' Index: ql/src/test/results/clientpositive/udaf_covar_samp.q.out =================================================================== --- ql/src/test/results/clientpositive/udaf_covar_samp.q.out (revision 0) +++ ql/src/test/results/clientpositive/udaf_covar_samp.q.out (revision 0) @@ -0,0 +1,80 @@ +PREHOOK: query: DROP TABLE covar_tab +PREHOOK: type: DROPTABLE +POSTHOOK: query: DROP TABLE covar_tab +POSTHOOK: type: DROPTABLE +PREHOOK: query: CREATE TABLE covar_tab (a INT, b INT, c INT) +ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' +STORED AS TEXTFILE +PREHOOK: type: CREATETABLE +POSTHOOK: query: CREATE TABLE covar_tab (a INT, b INT, c INT) +ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' +STORED AS TEXTFILE +POSTHOOK: type: CREATETABLE +POSTHOOK: Output: default@covar_tab +PREHOOK: query: LOAD DATA LOCAL INPATH '../data/files/covar_tab.txt' OVERWRITE +INTO TABLE covar_tab +PREHOOK: type: LOAD +POSTHOOK: query: LOAD DATA LOCAL INPATH '../data/files/covar_tab.txt' OVERWRITE +INTO TABLE 
covar_tab +POSTHOOK: type: LOAD +POSTHOOK: Output: default@covar_tab +PREHOOK: query: DESCRIBE FUNCTION cover_samp +PREHOOK: type: DESCFUNCTION +POSTHOOK: query: DESCRIBE FUNCTION cover_samp +POSTHOOK: type: DESCFUNCTION +Function 'cover_samp' does not exist. +PREHOOK: query: SELECT covar_samp(b, c) FROM covar_tab WHERE a < 1 +PREHOOK: type: QUERY +PREHOOK: Input: default@covar_tab +PREHOOK: Output: file:/tmp/hadoop/hive_2010-08-12_14-16-35_676_3641926801700375426/10000 +POSTHOOK: query: SELECT covar_samp(b, c) FROM covar_tab WHERE a < 1 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@covar_tab +POSTHOOK: Output: file:/tmp/hadoop/hive_2010-08-12_14-16-35_676_3641926801700375426/10000 +NULL +PREHOOK: query: SELECT covar_samp(b, c) FROM covar_tab WHERE a < 3 +PREHOOK: type: QUERY +PREHOOK: Input: default@covar_tab +PREHOOK: Output: file:/tmp/hadoop/hive_2010-08-12_14-16-40_604_3037562346609947428/10000 +POSTHOOK: query: SELECT covar_samp(b, c) FROM covar_tab WHERE a < 3 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@covar_tab +POSTHOOK: Output: file:/tmp/hadoop/hive_2010-08-12_14-16-40_604_3037562346609947428/10000 +NULL +PREHOOK: query: SELECT covar_samp(b, c) FROM covar_tab WHERE a = 3 +PREHOOK: type: QUERY +PREHOOK: Input: default@covar_tab +PREHOOK: Output: file:/tmp/hadoop/hive_2010-08-12_14-16-44_298_3268259707600760236/10000 +POSTHOOK: query: SELECT covar_samp(b, c) FROM covar_tab WHERE a = 3 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@covar_tab +POSTHOOK: Output: file:/tmp/hadoop/hive_2010-08-12_14-16-44_298_3268259707600760236/10000 +0.0 +PREHOOK: query: SELECT a, covar_samp(b, c) FROM covar_tab GROUP BY a ORDER BY a +PREHOOK: type: QUERY +PREHOOK: Input: default@covar_tab +PREHOOK: Output: file:/tmp/hadoop/hive_2010-08-12_14-16-48_692_5070115246736637911/10000 +POSTHOOK: query: SELECT a, covar_samp(b, c) FROM covar_tab GROUP BY a ORDER BY a +POSTHOOK: type: QUERY +POSTHOOK: Input: default@covar_tab +POSTHOOK: Output: 
file:/tmp/hadoop/hive_2010-08-12_14-16-48_692_5070115246736637911/10000 +1 NULL +2 NULL +3 0.0 +4 0.0 +5 0.0 +6 0.0 +PREHOOK: query: SELECT covar_samp(b, c) FROM covar_tab +PREHOOK: type: QUERY +PREHOOK: Input: default@covar_tab +PREHOOK: Output: file:/tmp/hadoop/hive_2010-08-12_14-16-54_583_4265603027362600529/10000 +POSTHOOK: query: SELECT covar_samp(b, c) FROM covar_tab +POSTHOOK: type: QUERY +POSTHOOK: Input: default@covar_tab +POSTHOOK: Output: file:/tmp/hadoop/hive_2010-08-12_14-16-54_583_4265603027362600529/10000 +4.833333333333332 +PREHOOK: query: DROP TABLE covar_tab +PREHOOK: type: DROPTABLE +POSTHOOK: query: DROP TABLE covar_tab +POSTHOOK: type: DROPTABLE +POSTHOOK: Output: default@covar_tab Index: ql/src/test/queries/clientpositive/udaf_covar_samp.q =================================================================== --- ql/src/test/queries/clientpositive/udaf_covar_samp.q (revision 0) +++ ql/src/test/queries/clientpositive/udaf_covar_samp.q (revision 0) @@ -0,0 +1,15 @@ +DROP TABLE covar_tab; +CREATE TABLE covar_tab (a INT, b INT, c INT) +ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' +STORED AS TEXTFILE; +LOAD DATA LOCAL INPATH '../data/files/covar_tab.txt' OVERWRITE +INTO TABLE covar_tab; + +DESCRIBE FUNCTION cover_samp; +SELECT covar_samp(b, c) FROM covar_tab WHERE a < 1; +SELECT covar_samp(b, c) FROM covar_tab WHERE a < 3; +SELECT covar_samp(b, c) FROM covar_tab WHERE a = 3; +SELECT a, covar_samp(b, c) FROM covar_tab GROUP BY a ORDER BY a; +SELECT covar_samp(b, c) FROM covar_tab; + +DROP TABLE covar_tab; Index: ql/src/test/queries/clientpositive/udaf_covar_pop.q =================================================================== --- ql/src/test/queries/clientpositive/udaf_covar_pop.q (revision 0) +++ ql/src/test/queries/clientpositive/udaf_covar_pop.q (revision 0) @@ -0,0 +1,15 @@ +DROP TABLE covar_tab; +CREATE TABLE covar_tab (a INT, b INT, c INT) +ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' +STORED AS TEXTFILE; +LOAD DATA LOCAL INPATH 
'../data/files/covar_tab.txt' OVERWRITE +INTO TABLE covar_tab; + +DESCRIBE FUNCTION cover_pop; +SELECT covar_pop(b, c) FROM covar_tab WHERE a < 1; +SELECT covar_pop(b, c) FROM covar_tab WHERE a < 3; +SELECT covar_pop(b, c) FROM covar_tab WHERE a = 3; +SELECT a, covar_pop(b, c) FROM covar_tab GROUP BY a ORDER BY a; +SELECT covar_pop(b, c) FROM covar_tab; + +DROP TABLE covar_tab; Index: ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java =================================================================== --- ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java (revision 984986) +++ ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java (working copy) @@ -134,6 +134,8 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFBridge; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCount; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCovariance; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCovarianceSample; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFHistogramNumeric; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectSet; @@ -365,6 +367,8 @@ registerGenericUDAF("variance", new GenericUDAFVariance()); registerGenericUDAF("var_pop", new GenericUDAFVariance()); registerGenericUDAF("var_samp", new GenericUDAFVarianceSample()); + registerGenericUDAF("covar_pop", new GenericUDAFCovariance()); + registerGenericUDAF("covar_samp", new GenericUDAFCovarianceSample()); registerGenericUDAF("histogram_numeric", new GenericUDAFHistogramNumeric()); registerGenericUDAF("percentile_approx", new GenericUDAFPercentileApprox()); Index: ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCovarianceSample.java =================================================================== --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCovarianceSample.java 
(revision 0) +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCovarianceSample.java (revision 0) @@ -0,0 +1,113 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.udf.generic; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; + +/** + * Compute the sample covariance by extending GenericUDAFCovariance and overriding + * the terminate() method of the evaluator. 
+ * + */ +@Description(name = "covar_samp", + value = "_FUNC_(x,y) - Returns the sample covariance of a set of number pairs") +public class GenericUDAFCovarianceSample extends GenericUDAFCovariance { + + @Override + public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) + throws SemanticException { + if (parameters.length != 2) { + throw new UDFArgumentTypeException(parameters.length - 1, + "Exactly two arguments are expected."); + } + + if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) { + throw new UDFArgumentTypeException(0, + "Only primitive type arguments are accepted but " + + parameters[0].getTypeName() + " is passed."); + } + + if (parameters[1].getCategory() != ObjectInspector.Category.PRIMITIVE) { + throw new UDFArgumentTypeException(1, + "Only primitive type arguments are accepted but " + + parameters[1].getTypeName() + " is passed."); + } + + switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) { + case BYTE: + case SHORT: + case INT: + case LONG: + case FLOAT: + case DOUBLE: + case STRING: + switch (((PrimitiveTypeInfo) parameters[1]).getPrimitiveCategory()) { + case BYTE: + case SHORT: + case INT: + case LONG: + case FLOAT: + case DOUBLE: + case STRING: + return new GenericUDAFCovarianceSampleEvaluator(); + case BOOLEAN: + default: + throw new UDFArgumentTypeException(1, + "Only numeric or string type arguments are accepted but " + + parameters[1].getTypeName() + " is passed."); + } + case BOOLEAN: + default: + throw new UDFArgumentTypeException(0, + "Only numeric or string type arguments are accepted but " + + parameters[0].getTypeName() + " is passed."); + } + } + + /** + * Compute the sample covariance by extending GenericUDAFCovarianceEvaluator and + * overriding the terminate() method of the evaluator. 
+ */ + public static class GenericUDAFCovarianceSampleEvaluator extends + GenericUDAFCovarianceEvaluator { + + @Override + public Object terminate(AggregationBuffer agg) throws HiveException { + StdAgg myagg = (StdAgg) agg; + + if (myagg.count == 0) { // SQL standard - return null for zero elements + return null; + } + else { + if (myagg.count > 1) { + getResult().set(myagg.covar / (myagg.count - 1)); + } else { // for one element the covariance is always 0 + getResult().set(0); + } + return getResult(); + } + } + } + +} Index: ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCovariance.java =================================================================== --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCovariance.java (revision 0) +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCovariance.java (revision 0) @@ -0,0 +1,344 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.hadoop.hive.ql.udf.generic; + +import java.util.ArrayList; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.util.StringUtils; + +/** + * Compute the covariance covar_pop(x, y), using the following one-pass method + * (ref. 
"Formulas for Robust, One-Pass Parallel Computation of Covariances and + * Arbitrary-Order Statistical Moments", Philippe Pebay, Sandia Labs): + * + * Incremental: + * n : + * mx_n = mx_(n-1) + [x_n - mx_(n-1)]/n : + * my_n = my_(n-1) + [y_n - my_(n-1)]/n : + * c_n = c_(n-1) + (x_n - mx_(n-1))*(y_n - my_n) : + * + * Merge: + * c_X = c_A + c_B + (mx_A - mx_B)*(my_A - my_B)*n_A*n_B/n_X + * + */ +@Description(name = "covariance,covar_pop", + value = "_FUNC_(x,y) - Returns the covariance of a set of number pairs") +public class GenericUDAFCovariance extends AbstractGenericUDAFResolver { + + static final Log LOG = LogFactory.getLog(GenericUDAFCovariance.class.getName()); + + @Override + public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException { + if (parameters.length != 2) { + throw new UDFArgumentTypeException(parameters.length - 1, + "Exactly two arguments are expected."); + } + + if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) { + throw new UDFArgumentTypeException(0, + "Only primitive type arguments are accepted but " + + parameters[0].getTypeName() + " is passed."); + } + + if (parameters[1].getCategory() != ObjectInspector.Category.PRIMITIVE) { + throw new UDFArgumentTypeException(1, + "Only primitive type arguments are accepted but " + + parameters[1].getTypeName() + " is passed."); + } + + switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) { + case BYTE: + case SHORT: + case INT: + case LONG: + case FLOAT: + case DOUBLE: + case STRING: + switch (((PrimitiveTypeInfo) parameters[1]).getPrimitiveCategory()) { + case BYTE: + case SHORT: + case INT: + case LONG: + case FLOAT: + case DOUBLE: + case STRING: + return new GenericUDAFCovarianceEvaluator(); + case BOOLEAN: + default: + throw new UDFArgumentTypeException(1, + "Only numeric or string type arguments are accepted but " + + parameters[1].getTypeName() + " is passed."); + } + case BOOLEAN: + default: + throw new UDFArgumentTypeException(0, 
+ "Only numeric or string type arguments are accepted but " + + parameters[0].getTypeName() + " is passed."); + } + } + + /** + * Evaluate the covariance using the algorithm described in + * http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance, + * presumably by Pébay, Philippe (2008), in "Formulas for Robust, + * One-Pass Parallel Computation of Covariances and Arbitrary-Order + * Statistical Moments", Technical Report SAND2008-6212, + * Sandia National Laboratories, + * http://infoserve.sandia.gov/sand_doc/2008/086212.pdf + * + * Incremental: + * n : + * mx_n = mx_(n-1) + [x_n - mx_(n-1)]/n : + * my_n = my_(n-1) + [y_n - my_(n-1)]/n : + * c_n = c_(n-1) + (x_n - mx_(n-1))*(y_n - my_n) : + * + * Merge: + * c_X = c_A + c_B + (mx_A - mx_B)*(my_A - my_B)*n_A*n_B/n_X + * + * This one-pass algorithm is stable. + * + */ + public static class GenericUDAFCovarianceEvaluator extends GenericUDAFEvaluator { + + // For PARTIAL1 and COMPLETE + private PrimitiveObjectInspector xInputOI; + private PrimitiveObjectInspector yInputOI; + + // For PARTIAL2 and FINAL + private StructObjectInspector soi; + private StructField countField; + private StructField xavgField; + private StructField yavgField; + private StructField covarField; + private LongObjectInspector countFieldOI; + private DoubleObjectInspector xavgFieldOI; + private DoubleObjectInspector yavgFieldOI; + private DoubleObjectInspector covarFieldOI; + + // For PARTIAL1 and PARTIAL2 + private Object[] partialResult; + + // For FINAL and COMPLETE + private DoubleWritable result; + + @Override + public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException { +// assert (parameters.length == 2); + super.init(m, parameters); + + // init input + if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) { + xInputOI = (PrimitiveObjectInspector) parameters[0]; + yInputOI = (PrimitiveObjectInspector) parameters[1]; + } + else { + soi = (StructObjectInspector) parameters[0]; + + countField = 
soi.getStructFieldRef("count"); + xavgField = soi.getStructFieldRef("xavg"); + yavgField = soi.getStructFieldRef("yavg"); + covarField = soi.getStructFieldRef("covar"); + + countFieldOI = + (LongObjectInspector) countField.getFieldObjectInspector(); + xavgFieldOI = + (DoubleObjectInspector) xavgField.getFieldObjectInspector(); + yavgFieldOI = + (DoubleObjectInspector) yavgField.getFieldObjectInspector(); + covarFieldOI = + (DoubleObjectInspector) covarField.getFieldObjectInspector(); + } + + // init output + if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) { + // The output of a partial aggregation is a struct containing + // a long count, two double averages, and a double covariance. + + ArrayList foi = new ArrayList(); + + foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + + ArrayList fname = new ArrayList(); + fname.add("count"); + fname.add("xavg"); + fname.add("yavg"); + fname.add("covar"); + + partialResult = new Object[4]; + partialResult[0] = new LongWritable(0); + partialResult[1] = new DoubleWritable(0); + partialResult[2] = new DoubleWritable(0); + partialResult[3] = new DoubleWritable(0); + + return ObjectInspectorFactory.getStandardStructObjectInspector(fname, + foi); + + } + else { + setResult(new DoubleWritable(0)); + return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; + } + } + + static class StdAgg implements AggregationBuffer { + long count; // number n of elements + double xavg; // average of x elements + double yavg; // average of y elements + double covar; // n times the covariance + }; + + @Override + public AggregationBuffer getNewAggregationBuffer() throws HiveException { + StdAgg result = new StdAgg(); + reset(result); + return result; + } + + @Override + public void 
reset(AggregationBuffer agg) throws HiveException { + StdAgg myagg = (StdAgg) agg; + myagg.count = 0; + myagg.xavg = 0; + myagg.yavg = 0; + myagg.covar = 0; + } + + private boolean warned = false; + + @Override + public void iterate(AggregationBuffer agg, Object[] parameters) + throws HiveException { + assert (parameters.length == 2); + Object px = parameters[0]; + Object py = parameters[1]; + if (px != null && py != null) { + StdAgg myagg = (StdAgg) agg; + try { + double vx = PrimitiveObjectInspectorUtils.getDouble(px, xInputOI); + double vy = PrimitiveObjectInspectorUtils.getDouble(py, yInputOI); + myagg.count++; + myagg.yavg = myagg.yavg + (vy - myagg.yavg)/myagg.count; + if( myagg.count > 1 ) { + myagg.covar += (vx - myagg.xavg)*(vy - myagg.yavg); + } + myagg.xavg = myagg.xavg + (vx - myagg.xavg)/myagg.count; + } + catch (NumberFormatException e) { + if (!warned) { + warned = true; + LOG.warn(getClass().getSimpleName() + " " + + StringUtils.stringifyException(e)); + LOG.warn(getClass().getSimpleName() + + " ignoring similar exceptions."); + } + } + } + } + + @Override + public Object terminatePartial(AggregationBuffer agg) throws HiveException { + StdAgg myagg = (StdAgg) agg; + ((LongWritable) partialResult[0]).set(myagg.count); + ((DoubleWritable) partialResult[1]).set(myagg.xavg); + ((DoubleWritable) partialResult[2]).set(myagg.yavg); + ((DoubleWritable) partialResult[3]).set(myagg.covar); + return partialResult; + } + + @Override + public void merge(AggregationBuffer agg, Object partial) throws HiveException { + if (partial != null) { + StdAgg myagg = (StdAgg) agg; + + Object partialCount = soi.getStructFieldData(partial, countField); + Object partialXAvg = soi.getStructFieldData(partial, xavgField); + Object partialYAvg = soi.getStructFieldData(partial, yavgField); + Object partialCovar = soi.getStructFieldData(partial, covarField); + + long nA = myagg.count; + long nB = countFieldOI.get(partialCount); + + if (nA == 0) { + // Just copy the information since 
there is nothing so far + myagg.count = countFieldOI.get(partialCount); + myagg.xavg = xavgFieldOI.get(partialXAvg); + myagg.yavg = yavgFieldOI.get(partialYAvg); + myagg.covar = covarFieldOI.get(partialCovar); + } + + if (nA != 0 && nB != 0) { + // Merge the two partials + double xavgA = myagg.xavg; + double yavgA = myagg.yavg; + double xavgB = xavgFieldOI.get(partialXAvg); + double yavgB = yavgFieldOI.get(partialYAvg); + double covarB = covarFieldOI.get(partialCovar); + + myagg.count += nB; + myagg.xavg = (xavgA*nA + xavgB*nB)/myagg.count; + myagg.yavg = (yavgA*nA + yavgB*nB)/myagg.count; + myagg.covar += covarB + (xavgA - xavgB)*(yavgA - yavgB)*((double)(nA*nB)/myagg.count); + } + } + } + + @Override + public Object terminate(AggregationBuffer agg) throws HiveException { + StdAgg myagg = (StdAgg) agg; + + if (myagg.count == 0) { // SQL standard - return null for zero elements + return null; + } + else { + getResult().set(myagg.covar / (myagg.count)); + return getResult(); + } + } + + public void setResult(DoubleWritable result) { + this.result = result; + } + + public DoubleWritable getResult() { + return result; + } + } + +}