diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCorrelation.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCorrelation.java index 5694e24..8056931 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCorrelation.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCorrelation.java @@ -19,8 +19,6 @@ import java.util.ArrayList; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.metadata.HiveException; @@ -74,8 +72,6 @@ + "and STDDEV_POP is the population standard deviation.") public class GenericUDAFCorrelation extends AbstractGenericUDAFResolver { - static final Log LOG = LogFactory.getLog(GenericUDAFCorrelation.class.getName()); - @Override public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException { if (parameters.length != 2) { @@ -289,15 +285,15 @@ public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveExcep StdAgg myagg = (StdAgg) agg; double vx = PrimitiveObjectInspectorUtils.getDouble(px, xInputOI); double vy = PrimitiveObjectInspectorUtils.getDouble(py, yInputOI); - double xavgOld = myagg.xavg; - double yavgOld = myagg.yavg; + double deltaX = vx - myagg.xavg; + double deltaY = vy - myagg.yavg; myagg.count++; - myagg.xavg += (vx - xavgOld) / myagg.count; - myagg.yavg += (vy - yavgOld) / myagg.count; + myagg.xavg += deltaX / myagg.count; + myagg.yavg += deltaY / myagg.count; if (myagg.count > 1) { - myagg.covar += (vx - xavgOld) * (vy - myagg.yavg); - myagg.xvar += (vx - xavgOld) * (vx - myagg.xavg); - myagg.yvar += (vy - yavgOld) * (vy - myagg.yavg); + myagg.covar += deltaX * (vy - myagg.yavg); + myagg.xvar += deltaX * (vx - myagg.xavg); + myagg.yvar += deltaY * (vy - myagg.yavg); } } } @@ -352,8 +348,8 @@ public void merge(AggregationBuffer agg, 
Object partial) throws HiveException { myagg.count += nB; myagg.xavg = (xavgA * nA + xavgB * nB) / myagg.count; myagg.yavg = (yavgA * nA + yavgB * nB) / myagg.count; - myagg.xvar += xvarB + (xavgA - xavgB) * (xavgA - xavgB) * myagg.count; - myagg.yvar += yvarB + (yavgA - yavgB) * (yavgA - yavgB) * myagg.count; + myagg.xvar += xvarB + (xavgA - xavgB) * (xavgA - xavgB) * nA * nB / myagg.count; + myagg.yvar += yvarB + (yavgA - yavgB) * (yavgA - yavgB) * nA * nB / myagg.count; myagg.covar += covarB + (xavgA - xavgB) * (yavgA - yavgB) * ((double) (nA * nB) / myagg.count); } diff --git ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDAFCorrelation.java ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDAFCorrelation.java new file mode 100644 index 0000000..ad29b88 --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDAFCorrelation.java @@ -0,0 +1,67 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.hadoop.hive.ql.udf.generic;
+
+import junit.framework.TestCase;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
+
+public class TestGenericUDAFCorrelation extends TestCase {
+  /** Merges two partial aggregations (3 and 4 rows) and checks the final correlation value. */
+  public void testCorr() throws HiveException {
+    GenericUDAFCorrelation corr = new GenericUDAFCorrelation();
+    GenericUDAFEvaluator eval1 = corr.getEvaluator(
+        new TypeInfo[]{TypeInfoFactory.doubleTypeInfo,TypeInfoFactory.doubleTypeInfo });
+    GenericUDAFEvaluator eval2 = corr.getEvaluator(
+        new TypeInfo[]{TypeInfoFactory.doubleTypeInfo,TypeInfoFactory.doubleTypeInfo });
+
+    ObjectInspector poi1 = eval1.init(GenericUDAFEvaluator.Mode.PARTIAL1,
+        new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
+        PrimitiveObjectInspectorFactory.javaDoubleObjectInspector});
+    eval2.init(GenericUDAFEvaluator.Mode.PARTIAL1,
+        new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
+        PrimitiveObjectInspectorFactory.javaDoubleObjectInspector});
+    // First partial aggregation: three (x, y) rows fed through eval1.
+    GenericUDAFEvaluator.AggregationBuffer buffer1 = eval1.getNewAggregationBuffer();
+    eval1.iterate(buffer1, new Object[]{100d, 200d});
+    eval1.iterate(buffer1, new Object[]{150d, 210d});
+    eval1.iterate(buffer1, new Object[]{200d, 220d});
+    Object object1 = eval1.terminatePartial(buffer1);
+    // Second partial aggregation: four (x, y) rows fed through eval2.
+    GenericUDAFEvaluator.AggregationBuffer buffer2 = eval2.getNewAggregationBuffer();
+    eval2.iterate(buffer2, new Object[]{250d, 230d});
+    eval2.iterate(buffer2, new Object[]{250d, 240d});
+    eval2.iterate(buffer2, new Object[]{300d, 250d});
+    eval2.iterate(buffer2, new Object[]{350d, 260d});
+    Object object2 = eval2.terminatePartial(buffer2);
+    // Re-initialise eval2 in FINAL mode, reading partials through eval1's output inspector.
+    eval2.init(GenericUDAFEvaluator.Mode.FINAL,
+        new ObjectInspector[]{poi1});
+
+    GenericUDAFEvaluator.AggregationBuffer buffer3 = eval2.getNewAggregationBuffer();
+    eval2.merge(buffer3, object1);
+    eval2.merge(buffer3, object2);
+    // Compare numerically with a tolerance: an exact string match breaks on any last-ulp change.
+    Object result = eval2.terminate(buffer3);
+    assertEquals(0.987829161147262, Double.parseDouble(String.valueOf(result)), 1e-12);
+  }
+}