diff --git common/src/java/org/apache/hive/common/util/HiveStringUtils.java common/src/java/org/apache/hive/common/util/HiveStringUtils.java index c21c937..346f4b8 100644 --- common/src/java/org/apache/hive/common/util/HiveStringUtils.java +++ common/src/java/org/apache/hive/common/util/HiveStringUtils.java @@ -826,4 +826,15 @@ public static int getTextUtfLength(Text t) { } return len; } + + public static int getUnpaddedLength(Text t) { + byte[] bytes = t.getBytes(); + int offset = t.getLength() - 1; + for (; offset >= 0; offset--) { + if (bytes[offset] != ' ') { + break; + } + } + return offset + 1; + } } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/GroupByOperator.java ql/src/java/org/apache/hadoop/hive/ql/exec/GroupByOperator.java index 792d87f..77ac7fb 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/GroupByOperator.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/GroupByOperator.java @@ -626,9 +626,9 @@ protected void updateAggregations(AggregationBuffer[] aggs, Object row, if (lastInvoke[ai] == null) { lastInvoke[ai] = new Object[o.length]; } - if (ObjectInspectorUtils.compare(o, + if (!ObjectInspectorUtils.equals(o, aggregationParameterObjectInspectors[ai], lastInvoke[ai], - aggregationParameterStandardObjectInspectors[ai]) != 0) { + aggregationParameterStandardObjectInspectors[ai])) { aggregationEvaluators[ai].aggregate(aggs[ai], o); for (int pi = 0; pi < o.length; pi++) { lastInvoke[ai][pi] = ObjectInspectorUtils.copyToStandardObject( @@ -676,10 +676,10 @@ protected void updateAggregations(AggregationBuffer[] aggs, Object row, if (lastInvoke[i] == null) { lastInvoke[i] = new Object[o.length]; } - if (ObjectInspectorUtils.compare(o, + if (!ObjectInspectorUtils.equals(o, aggregationParameterObjectInspectors[i], lastInvoke[i], - aggregationParameterStandardObjectInspectors[i]) != 0) { + aggregationParameterStandardObjectInspectors[i])) { aggregationEvaluators[i].aggregate(aggs[i], o); for (int pi = 0; pi < o.length; pi++) { lastInvoke[i][pi] 
= ObjectInspectorUtils.copyToStandardObject( diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFArrayContains.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFArrayContains.java index 510f367..e961f95 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFArrayContains.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFArrayContains.java @@ -117,8 +117,8 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { for (int i=0; i { @@ -86,6 +85,10 @@ public Text getPaddedValue() { return getTextValue(); } + public int getUnpaddedLength() { + return HiveStringUtils.getUnpaddedLength(getStrippedValue()); + } + public int getCharacterLength() { return HiveStringUtils.getTextUtfLength(getStrippedValue()); } diff --git serde/src/java/org/apache/hadoop/hive/serde2/io/HiveVarcharWritable.java serde/src/java/org/apache/hadoop/hive/serde2/io/HiveVarcharWritable.java index a165b84..8e15d70 100644 --- serde/src/java/org/apache/hadoop/hive/serde2/io/HiveVarcharWritable.java +++ serde/src/java/org/apache/hadoop/hive/serde2/io/HiveVarcharWritable.java @@ -17,14 +17,8 @@ */ package org.apache.hadoop.hive.serde2.io; -import java.io.DataInput; -import java.io.DataOutput; -import java.io.IOException; - import org.apache.hadoop.hive.common.type.HiveBaseChar; import org.apache.hadoop.hive.common.type.HiveVarchar; -import org.apache.hadoop.hive.shims.ShimLoader; -import org.apache.hadoop.io.Text; import org.apache.hadoop.io.WritableComparable; public class HiveVarcharWritable extends HiveBaseCharWritable diff --git serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ListObjectsEqualComparer.java serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ListObjectsEqualComparer.java index ed4979e..7a70695 100644 --- serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ListObjectsEqualComparer.java +++ 
serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ListObjectsEqualComparer.java @@ -124,8 +124,7 @@ public boolean areEqual(Object o0, Object o1) { return (soi0.getPrimitiveJavaObject(o0).equals( soi1.getPrimitiveJavaObject(o1))); default: - return (ObjectInspectorUtils.compare( - o0, oi0, o1, oi1) == 0); + return ObjectInspectorUtils.equals(o0, oi0, o1, oi1); } } } diff --git serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/MapEqualComparer.java serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/MapEqualComparer.java index adde408..5fdd282 100644 --- serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/MapEqualComparer.java +++ serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/MapEqualComparer.java @@ -18,6 +18,11 @@ package org.apache.hadoop.hive.serde2.objectinspector; public interface MapEqualComparer { + + public static final MapEqualComparer FULL = new FullMapEqualComparer(); + public static final MapEqualComparer CROSS = new CrossMapEqualComparer(); + public static final MapEqualComparer SIMPLE = new SimpleMapEqualComparer(); + /* * Compare the two map objects for equality. 
*/ diff --git serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java index 1baf359..9149922 100644 --- serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java +++ serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java @@ -69,9 +69,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.TimestampObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableStringObjectInspector; -import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; -import org.apache.hadoop.hive.shims.ShimLoader; import org.apache.hadoop.io.BytesWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.util.StringUtils; @@ -108,7 +106,7 @@ public static ObjectInspector getWritableObjectInspector(ObjectInspector oi) { PrimitiveObjectInspector poi = (PrimitiveObjectInspector) oi; if (!(poi instanceof AbstractPrimitiveWritableObjectInspector)) { return PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector( - (PrimitiveTypeInfo)poi.getTypeInfo()); + poi.getTypeInfo()); } } return oi; @@ -601,19 +599,19 @@ public static int hashCode(Object o, ObjectInspector objIns) { * Compare two arrays of objects with their respective arrays of * ObjectInspectors. 
*/
-  public static int compare(Object[] o1, ObjectInspector[] oi1, Object[] o2,
+  public static boolean equals(Object[] o1, ObjectInspector[] oi1, Object[] o2,
       ObjectInspector[] oi2) {
     assert (o1.length == oi1.length);
     assert (o2.length == oi2.length);
-    assert (o1.length == o2.length);
-
+    if (o1.length != o2.length) {
+      return false;
+    }
     for (int i = 0; i < o1.length; i++) {
-      int r = compare(o1[i], oi1[i], o2[i], oi2[i]);
-      if (r != 0) {
-        return r;
+      if (!equals(o1[i], oi1[i], o2[i], oi2[i])) {
+        return false;
       }
     }
-    return 0;
+    return true;
   }
 
   /**
@@ -656,7 +654,7 @@ public static boolean compareSupported(ObjectInspector oi) {
    */
   public static int compare(Object o1, ObjectInspector oi1, Object o2,
       ObjectInspector oi2) {
-    return compare(o1, oi1, o2, oi2, new FullMapEqualComparer());
+    return compare(o1, oi1, o2, oi2, MapEqualComparer.FULL);
   }
 
   /**
@@ -834,6 +832,247 @@ public static int compare(Object o1, ObjectInspector oi1, Object o2,
     }
   }
 
+  /**
+   * Tests two objects for equality using their respective ObjectInspectors.
+   * Map values are compared with {@link MapEqualComparer#FULL}.
+   *
+   * @return true if both objects are considered equal
+   */
+  public static boolean equals(Object o1, ObjectInspector oi1, Object o2,
+      ObjectInspector oi2) {
+    return equals(o1, oi1, o2, oi2, MapEqualComparer.FULL);
+  }
+
+  /**
+   * Tests two objects for equality using their respective ObjectInspectors.
+   * Inspectors of different category (or primitive category) never match;
+   * after that check two nulls are equal and null never equals non-null.
+   * FLOAT and DOUBLE go through Float.compare/Double.compare, so NaN equals
+   * NaN and 0.0 differs from -0.0; with a plain == a NaN would never equal
+   * itself and callers testing !equals(...) would always see a change.
+   *
+   * @param mapEqualComparer strategy used for MAP values; must not be null
+   *          when maps are compared
+   * @return true if both objects are considered equal
+   * @throws RuntimeException on an unknown category or primitive category, or
+   *           when maps are compared with a null mapEqualComparer
+   */
+  public static boolean equals(Object o1, ObjectInspector oi1, Object o2,
+      ObjectInspector oi2, MapEqualComparer mapEqualComparer) {
+    if (oi1.getCategory() != oi2.getCategory()) {
+      return false;
+    }
+    if (o1 == null) {
+      return o2 == null;
+    } else if (o2 == null) {
+      return false;
+    }
+
+    switch (oi1.getCategory()) {
+    case PRIMITIVE: {
+      PrimitiveObjectInspector poi1 = ((PrimitiveObjectInspector) oi1);
+      PrimitiveObjectInspector poi2 = ((PrimitiveObjectInspector) oi2);
+      if (poi1.getPrimitiveCategory() != poi2.getPrimitiveCategory()) {
+        return false;
+      }
+      switch (poi1.getPrimitiveCategory()) {
+      case VOID:
+        return true;
+      case BOOLEAN: {
+        boolean v1 = ((BooleanObjectInspector) poi1).get(o1);
+        boolean v2 = ((BooleanObjectInspector) poi2).get(o2);
+        return v1 == v2;
+      }
+      case BYTE: {
+        int v1 = ((ByteObjectInspector) poi1).get(o1);
+        int v2 = ((ByteObjectInspector) poi2).get(o2);
+        return v1 == v2;
+      }
+      case SHORT: {
+        int v1 = ((ShortObjectInspector) poi1).get(o1);
+        int v2 = ((ShortObjectInspector) poi2).get(o2);
+        return v1 == v2;
+      }
+      case INT: {
+        int v1 = ((IntObjectInspector) poi1).get(o1);
+        int v2 = ((IntObjectInspector) poi2).get(o2);
+        return v1 == v2;
+      }
+      case LONG: {
+        long v1 = ((LongObjectInspector) poi1).get(o1);
+        long v2 = ((LongObjectInspector) poi2).get(o2);
+        return v1 == v2;
+      }
+      case FLOAT: {
+        float v1 = ((FloatObjectInspector) poi1).get(o1);
+        float v2 = ((FloatObjectInspector) poi2).get(o2);
+        return Float.compare(v1, v2) == 0; // NaN equals NaN, 0.0f != -0.0f
+      }
+      case DOUBLE: {
+        double v1 = ((DoubleObjectInspector) poi1).get(o1);
+        double v2 = ((DoubleObjectInspector) poi2).get(o2);
+        return Double.compare(v1, v2) == 0; // NaN equals NaN, 0.0 != -0.0
+      }
+      case STRING: {
+        if (poi1.preferWritable() || poi2.preferWritable()) {
+          Text t1 = (Text) poi1.getPrimitiveWritableObject(o1);
+          Text t2 = (Text) poi2.getPrimitiveWritableObject(o2);
+          return isEqualText(t1, t2);
+        }
+        String s1 = (String) poi1.getPrimitiveJavaObject(o1);
+        String s2 = (String) poi2.getPrimitiveJavaObject(o2);
+        return s1.equals(s2);
+      }
+      case CHAR: {
+        HiveCharWritable c1 = ((HiveCharObjectInspector)poi1).getPrimitiveWritableObject(o1);
+        HiveCharWritable c2 = ((HiveCharObjectInspector)poi2).getPrimitiveWritableObject(o2);
+        return isEqualChar(c1, c2);
+      }
+      case VARCHAR: {
+        HiveVarcharWritable t1 = ((HiveVarcharObjectInspector)poi1).getPrimitiveWritableObject(o1);
+        HiveVarcharWritable t2 = ((HiveVarcharObjectInspector)poi2).getPrimitiveWritableObject(o2);
+        return isEqualVarchar(t1, t2);
+      }
+      case BINARY: {
+        BytesWritable bw1 = ((BinaryObjectInspector) poi1).getPrimitiveWritableObject(o1);
+        BytesWritable bw2 = ((BinaryObjectInspector) poi2).getPrimitiveWritableObject(o2);
+        return isEqualBinary(bw1, bw2);
+      }
+
+      case DATE: {
+        DateWritable d1 = ((DateObjectInspector) poi1)
+            .getPrimitiveWritableObject(o1);
+        DateWritable d2 = ((DateObjectInspector) poi2)
+            .getPrimitiveWritableObject(o2);
+        return d1.equals(d2);
+      }
+      case TIMESTAMP: {
+        TimestampWritable t1 = ((TimestampObjectInspector) poi1)
+            .getPrimitiveWritableObject(o1);
+        TimestampWritable t2 = ((TimestampObjectInspector) poi2)
+            .getPrimitiveWritableObject(o2);
+        return t1.equals(t2);
+      }
+      case DECIMAL: {
+        HiveDecimalWritable t1 = ((HiveDecimalObjectInspector) poi1)
+            .getPrimitiveWritableObject(o1);
+        HiveDecimalWritable t2 = ((HiveDecimalObjectInspector) poi2)
+            .getPrimitiveWritableObject(o2);
+        return t1.equals(t2);
+      }
+      default: {
+        throw new RuntimeException("Unknown type: "
+            + poi1.getPrimitiveCategory());
+      }
+      }
+    }
+    case STRUCT: {
+      StructObjectInspector soi1 = (StructObjectInspector) oi1;
+      StructObjectInspector soi2 = (StructObjectInspector) oi2;
+      List<? extends StructField> fields1 = soi1.getAllStructFieldRefs();
+      List<? extends StructField> fields2 = soi2.getAllStructFieldRefs();
+      int size1 = fields1.size();
+      int size2 = fields2.size();
+      if (size1 != size2) {
+        return false;
+      }
+      for (int i = 0; i < size1; i++) {
+        if (!equals(soi1.getStructFieldData(o1, fields1.get(i)), fields1
+            .get(i).getFieldObjectInspector(), soi2.getStructFieldData(o2,
+            fields2.get(i)), fields2.get(i).getFieldObjectInspector(), mapEqualComparer)) {
+          return false;
+        }
+      }
+      return true;
+    }
+    case LIST: {
+      ListObjectInspector loi1 = (ListObjectInspector) oi1;
+      ListObjectInspector loi2 = (ListObjectInspector) oi2;
+      int size1 = loi1.getListLength(o1);
+      int size2 = loi2.getListLength(o2);
+      if (size1 != size2) {
+        return false;
+      }
+      for (int i = 0; i < size1; i++) {
+        if (!equals(loi1.getListElement(o1, i), loi1
+            .getListElementObjectInspector(), loi2.getListElement(o2, i), loi2
+            .getListElementObjectInspector(), mapEqualComparer)) {
+          return false;
+        }
+      }
+      return true;
+    }
+    case MAP: {
+      if (mapEqualComparer == null) {
+        throw new RuntimeException("Compare on map type not supported!");
+      }
+      MapObjectInspector moi1 = (MapObjectInspector) oi1;
+      MapObjectInspector moi2 = (MapObjectInspector) oi2;
+      int size1 = moi1.getMapSize(o1);
+      int size2 = moi2.getMapSize(o2);
+      if (size1 != size2) {
+        return false;
+      }
+      return mapEqualComparer.compare(o1, moi1, o2, moi2) == 0;
+    }
+    case UNION: {
+      UnionObjectInspector uoi1 = (UnionObjectInspector) oi1;
+      UnionObjectInspector uoi2 = (UnionObjectInspector) oi2;
+      byte tag1 = uoi1.getTag(o1);
+      byte tag2 = uoi2.getTag(o2);
+      if (tag1 != tag2) {
+        return false;
+      }
+      return equals(uoi1.getField(o1),
+          uoi1.getObjectInspectors().get(tag1),
+          uoi2.getField(o2), uoi2.getObjectInspectors().get(tag2),
+          mapEqualComparer);
+    }
+    default:
+      throw new RuntimeException("Compare on unknown type: "
+          + oi1.getCategory());
+    }
+  }
+
+  private static boolean isEqualVarchar(HiveVarcharWritable vc1, HiveVarcharWritable vc2) {
+    return isEqualText(vc1.getTextValue(), vc2.getTextValue());
+  }
+
+  private static boolean isEqualChar(HiveCharWritable c1, HiveCharWritable c2) {
+    int length1 = c1.getUnpaddedLength();
+    int length2 = c2.getUnpaddedLength();
+    return length1 == length2 && isEqualText(c1.getPaddedValue(), c2.getPaddedValue(), length1);
+  }
+
+  private static boolean isEqualText(Text t1, Text t2) {
+    if (t1.getLength() == t2.getLength()) {
+      return isEqualText(t1, t2, t1.getLength());
+    }
+    return false;
+  }
+
+  private static boolean isEqualText(Text t1, Text t2, int length) {
+    byte[] b1 = t1.getBytes();
+    byte[] b2 = t2.getBytes();
+    return isEqualBinary(b1, b2, length);
+  }
+
+  private static boolean isEqualBinary(BytesWritable bw1, BytesWritable bw2) {
+    if (bw1.getLength() == bw2.getLength()) {
+      return isEqualBinary(bw1.getBytes(), bw2.getBytes(), bw1.getLength());
+    }
+    return false;
+  }
+
+  private static boolean isEqualBinary(byte[] b1, byte[] b2, int length) {
+    for (int i = 0; i < length; i++) {
+      if (b1[i] != b2[i]) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   /**
    * Get the list of field names as csv from a StructObjectInspector.
*/ diff --git serde/src/test/org/apache/hadoop/hive/serde2/binarysortable/TestBinarySortableSerDe.java serde/src/test/org/apache/hadoop/hive/serde2/binarysortable/TestBinarySortableSerDe.java index cefb72e..7b6bdc0 100644 --- serde/src/test/org/apache/hadoop/hive/serde2/binarysortable/TestBinarySortableSerDe.java +++ serde/src/test/org/apache/hadoop/hive/serde2/binarysortable/TestBinarySortableSerDe.java @@ -112,7 +112,7 @@ private void testBinarySortableSerDe(Object[] rows, ObjectInspector rowOI, Object[] deserialized = new Object[rows.length]; for (int i = 0; i < rows.length; i++) { deserialized[i] = serde.deserialize(bytes[i]); - if (0 != ObjectInspectorUtils.compare(rows[i], rowOI, deserialized[i], + if (!ObjectInspectorUtils.equals(rows[i], rowOI, deserialized[i], serdeOI)) { System.out.println("structs[" + i + "] = " + SerDeUtils.getJSONString(rows[i], rowOI)); diff --git serde/src/test/org/apache/hadoop/hive/serde2/lazybinary/TestLazyBinarySerDe.java serde/src/test/org/apache/hadoop/hive/serde2/lazybinary/TestLazyBinarySerDe.java index 02ae6f8..7ec3e65 100644 --- serde/src/test/org/apache/hadoop/hive/serde2/lazybinary/TestLazyBinarySerDe.java +++ serde/src/test/org/apache/hadoop/hive/serde2/lazybinary/TestLazyBinarySerDe.java @@ -129,7 +129,7 @@ private void testLazyBinarySerDe(Object[] rows, ObjectInspector rowOI, Object[] deserialized = new Object[rows.length]; for (int i = 0; i < rows.length; i++) { deserialized[i] = serde.deserialize(bytes[i]); - if (0 != ObjectInspectorUtils.compare(rows[i], rowOI, deserialized[i], + if (!ObjectInspectorUtils.equals(rows[i], rowOI, deserialized[i], serdeOI)) { System.out.println("structs[" + i + "] = " + SerDeUtils.getJSONString(rows[i], rowOI)); @@ -462,9 +462,9 @@ void testLazyBinaryMap(Random r) throws Throwable { boolean bEqual = false; for (Map.Entry entryoutput : outputmp.entrySet()) { // find the same key - if (0 == ObjectInspectorUtils.compare(entryoutput.getKey(), + if 
(ObjectInspectorUtils.equals(entryoutput.getKey(), lazympkeyoi, entryinput.getKey(), inputmpkeyoi)) { - if (0 != ObjectInspectorUtils.compare(entryoutput.getValue(), + if (!ObjectInspectorUtils.equals(entryoutput.getValue(), lazympvalueoi, entryinput.getValue(), inputmpvalueoi)) { assertEquals(entryoutput.getValue(), entryinput.getValue()); } else {