diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorInBloomFilterColDynamicValue.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorInBloomFilterColDynamicValue.java index 188a87e..1132dc1 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorInBloomFilterColDynamicValue.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorInBloomFilterColDynamicValue.java @@ -33,7 +33,7 @@ import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; import org.apache.hadoop.hive.ql.plan.DynamicValue; import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector; -import org.apache.hive.common.util.BloomFilter; +import org.apache.hive.common.util.Bloom1Filter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -46,7 +46,7 @@ protected int colNum; protected DynamicValue bloomFilterDynamicValue; protected transient boolean initialized = false; - protected transient BloomFilter bloomFilter; + protected transient Bloom1Filter bloomFilter; protected transient BloomFilterCheck bfCheck; public VectorInBloomFilterColDynamicValue(int colNum, DynamicValue bloomFilterDynamicValue) { @@ -95,7 +95,7 @@ private void initValue() { if (val != null) { BinaryObjectInspector boi = (BinaryObjectInspector) bloomFilterDynamicValue.getObjectInspector(); byte[] bytes = boi.getPrimitiveJavaObject(val); - bloomFilter = BloomFilter.deserialize(new ByteArrayInputStream(bytes)); + bloomFilter = Bloom1Filter.deserialize(new ByteArrayInputStream(bytes)); } else { bloomFilter = null; } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFBloomFilter.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFBloomFilter.java index 4b3eca09..4df19ac 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFBloomFilter.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFBloomFilter.java @@ -20,9 +20,9 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.util.Arrays; import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.apache.hadoop.hive.llap.LlapUtil; import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; @@ -33,18 +33,15 @@ import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; -import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression.AggregationBuffer; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.plan.AggregationDesc; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFBloomFilter; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFBloomFilter.GenericUDAFBloomFilterEvaluator; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.io.BytesWritable; -import org.apache.hadoop.io.Text; -import org.apache.hive.common.util.BloomFilter; +import org.apache.hive.common.util.Bloom1Filter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -66,10 +63,11 @@ private static final class Aggregation implements AggregationBuffer { private static final long serialVersionUID = 1L; - BloomFilter bf; + Bloom1Filter bf; public Aggregation(long expectedEntries) { - bf = new BloomFilter(expectedEntries); + bf = new Bloom1Filter(expectedEntries); + LOG.info("Bloomfilter initialized with expectedEntries: {} sizeInBytes: {}", expectedEntries, bf.sizeInBytes()); } @Override @@ -363,7 +361,8 @@ public Object evaluateOutput(AggregationBuffer agg) throws HiveException { try { Aggregation bfAgg = (Aggregation) agg; byteStream.reset(); - BloomFilter.serialize(byteStream, bfAgg.bf); + Bloom1Filter.serialize(byteStream, bfAgg.bf); + LOG.info("Bloomfilter serialized size: {}", LlapUtil.humanReadableByteCount(byteStream.size())); byte[] bytes = byteStream.toByteArray(); bw.set(bytes, 0, bytes.length); return bw; diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFBloomFilterMerge.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFBloomFilterMerge.java index 67a7c50..0c43e2e 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFBloomFilterMerge.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFBloomFilterMerge.java @@ -21,12 +21,12 @@ import java.io.ByteArrayOutputStream; import java.util.Arrays; +import org.apache.hadoop.hive.llap.LlapUtil; import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; -import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression.AggregationBuffer; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.plan.AggregationDesc; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; @@ -34,10 +34,12 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.io.BytesWritable; -import org.apache.hive.common.util.BloomFilter; +import org.apache.hive.common.util.Bloom1Filter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class VectorUDAFBloomFilterMerge extends VectorAggregateExpression { - + private static final Logger LOG = LoggerFactory.getLogger(VectorUDAFBloomFilterMerge.class); private static final long serialVersionUID = 1L; private long expectedEntries = -1; @@ -54,9 +56,10 @@ public Aggregation(long expectedEntries) { try { - BloomFilter bf = new BloomFilter(expectedEntries); + Bloom1Filter bf = new Bloom1Filter(expectedEntries); ByteArrayOutputStream bytesOut = new ByteArrayOutputStream(); - BloomFilter.serialize(bytesOut, bf); + Bloom1Filter.serialize(bytesOut, bf); + LOG.info("Bloomfilter serialized size: {}", LlapUtil.humanReadableByteCount(bytesOut.size())); bfBytes = bytesOut.toByteArray(); } catch (Exception err) { throw new IllegalArgumentException("Error creating aggregation buffer", err); @@ -71,7 +74,7 @@ public int getVariableSize() { @Override public void reset() { // Do not change the initial bytes which contain NumHashFunctions/NumBits! - Arrays.fill(bfBytes, BloomFilter.START_OF_SERIALIZED_LONGS, bfBytes.length, (byte) 0); + Arrays.fill(bfBytes, Bloom1Filter.START_OF_SERIALIZED_LONGS, bfBytes.length, (byte) 0); } } @@ -362,7 +365,7 @@ void processValue(Aggregation myagg, ColumnVector columnVector, int i) { // BloomFilter.mergeBloomFilterBytes() does a simple byte ORing // which should be faster than deserialize/merge. BytesColumnVector inputColumn = (BytesColumnVector) columnVector; - BloomFilter.mergeBloomFilterBytes(myagg.bfBytes, 0, myagg.bfBytes.length, + Bloom1Filter.mergeBloomFilterBytes(myagg.bfBytes, 0, myagg.bfBytes.length, inputColumn.vector[i], inputColumn.start[i], inputColumn.length[i]); } } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFBloomFilter.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFBloomFilter.java index 2413ae6..5902e92 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFBloomFilter.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFBloomFilter.java @@ -19,6 +19,7 @@ package org.apache.hadoop.hive.ql.udf.generic; import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.apache.hadoop.hive.llap.LlapUtil; import org.apache.hadoop.hive.ql.exec.Operator; import org.apache.hadoop.hive.ql.exec.SelectOperator; import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; @@ -28,18 +29,15 @@ import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc; import org.apache.hadoop.hive.ql.plan.ExprNodeDescUtils; import org.apache.hadoop.hive.ql.plan.Statistics; -import org.apache.hadoop.hive.ql.plan.Statistics.State; import org.apache.hadoop.hive.serde2.io.DateWritable; import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.*; -import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; -import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; import org.apache.hadoop.io.BytesWritable; import org.apache.hadoop.io.Text; -import org.apache.hive.common.util.BloomFilter; +import org.apache.hive.common.util.Bloom1Filter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -106,14 +104,16 @@ public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveExc */ @AggregationType(estimable = true) static class BloomFilterBuf extends AbstractAggregationBuffer { - BloomFilter bloomFilter; + Bloom1Filter bloomFilter; public BloomFilterBuf(long expectedEntries, long maxEntries) { if (expectedEntries > maxEntries) { - bloomFilter = new BloomFilter(1); + bloomFilter = new Bloom1Filter(1); } else { - bloomFilter = new BloomFilter(expectedEntries); + bloomFilter = new Bloom1Filter(expectedEntries); } + + LOG.info("Bloomfilter initialized with expectedEntries: {} sizeInBytes: {}", expectedEntries, estimate()); } @Override @@ -147,7 +147,7 @@ public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveExcep return; } - BloomFilter bf = ((BloomFilterBuf)agg).bloomFilter; + Bloom1Filter bf = ((BloomFilterBuf)agg).bloomFilter; // Add the expression into the BloomFilter switch (inputOI.getPrimitiveCategory()) { @@ -231,7 +231,7 @@ public void merge(AggregationBuffer agg, Object partial) throws HiveException { ByteArrayInputStream in = new ByteArrayInputStream(bytes.getBytes()); // Deserialze the bloomfilter try { - BloomFilter bf = BloomFilter.deserialize(in); + Bloom1Filter bf = Bloom1Filter.deserialize(in); ((BloomFilterBuf)agg).bloomFilter.merge(bf); } catch (IOException e) { throw new HiveException(e); @@ -242,7 +242,8 @@ public void merge(AggregationBuffer agg, Object partial) throws HiveException { public Object terminate(AggregationBuffer agg) throws HiveException { result.reset(); try { - BloomFilter.serialize(result, ((BloomFilterBuf)agg).bloomFilter); + Bloom1Filter.serialize(result, ((BloomFilterBuf)agg).bloomFilter); + LOG.info("Bloomfilter serialized size: {}", LlapUtil.humanReadableByteCount(result.size())); } catch (IOException e) { throw new HiveException(e); } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFInBloomFilter.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFInBloomFilter.java index 3e6e069..6ea749c 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFInBloomFilter.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFInBloomFilter.java @@ -33,7 +33,7 @@ import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; import org.apache.hadoop.io.BytesWritable; import org.apache.hadoop.io.Text; -import org.apache.hive.common.util.BloomFilter; +import org.apache.hive.common.util.Bloom1Filter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,7 +50,7 @@ private transient ObjectInspector valObjectInspector; private transient ObjectInspector bloomFilterObjectInspector; - private transient BloomFilter bloomFilter; + private transient Bloom1Filter bloomFilter; private transient boolean initializedBloomFilter; private transient byte[] scratchBuffer = new byte[HiveDecimal.SCRATCH_BUFFER_LEN_TO_BYTES]; @@ -99,7 +99,7 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { BytesWritable bw = (BytesWritable) arguments[1].get(); byte[] bytes = new byte[bw.getLength()]; System.arraycopy(bw.getBytes(), 0, bytes, 0, bw.getLength()); - bloomFilter = BloomFilter.deserialize(new ByteArrayInputStream(bytes)); + bloomFilter = Bloom1Filter.deserialize(new ByteArrayInputStream(bytes)); } catch ( IOException e) { throw new HiveException(e); } diff --git a/storage-api/src/java/org/apache/hive/common/util/Bloom1Filter.java b/storage-api/src/java/org/apache/hive/common/util/Bloom1Filter.java new file mode 100644 index 0000000..2c68254 --- /dev/null +++ b/storage-api/src/java/org/apache/hive/common/util/Bloom1Filter.java @@ -0,0 +1,469 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hive.common.util; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.List; + +/** + * BloomFilter is a probabilistic data structure for set membership check. BloomFilters are + * highly space efficient when compared to using a HashSet. Because of the probabilistic nature of + * bloom filter false positive (element not present in bloom filter but test() says true) are + * possible but false negatives are not possible (if element is present then test() will never + * say false). The false positive probability is configurable (default: 5%) depending on which + * storage requirement may increase or decrease. Lower the false positive probability greater + * is the space requirement. + * Bloom filters are sensitive to number of elements that will be inserted in the bloom filter. + * During the creation of bloom filter expected number of entries must be specified. If the number + * of insertions exceed the specified initial number of entries then false positive probability will + * increase accordingly. + * + * Internally, this implementation of bloom filter uses Murmur3 fast non-cryptographic hash + * algorithm. Although Murmur2 is slightly faster than Murmur3 in Java, it suffers from hash + * collisions for specific sequence of repeating bytes. Check the following link for more info + * https://code.google.com/p/smhasher/wiki/MurmurHash2Flaw + */ +public class Bloom1Filter { + private byte[] BYTE_ARRAY_4 = new byte[4]; + private byte[] BYTE_ARRAY_8 = new byte[8]; + private static final double DEFAULT_FPP = 0.05; + private BitSet bitSet; + private long m; + private int k; + + public Bloom1Filter(long maxNumEntries) { + this(maxNumEntries, DEFAULT_FPP); + } + + static void checkArgument(boolean expression, String message) { + if (!expression) { + throw new IllegalArgumentException(message); + } + } + + public Bloom1Filter(long maxNumEntries, double fpp) { + checkArgument(maxNumEntries > 0, "expectedEntries should be > 0"); + checkArgument(fpp > 0.0 && fpp < 1.0, "False positive probability should be > 0.0 & < 1.0"); + int nb = (int) optimalNumOfBits(maxNumEntries, fpp); + // make 'm' multiple of 64 + this.m = nb + (Long.SIZE - (nb % Long.SIZE)); + this.k = optimalNumOfHashFunctions(maxNumEntries, m); + this.bitSet = new BitSet(m); + } + + // deserialize bloomfilter. see serialize() for the format. + public Bloom1Filter(List serializedBloom) { + this(serializedBloom.get(0), Double.longBitsToDouble(serializedBloom.get(1))); + List bitSet = serializedBloom.subList(2, serializedBloom.size()); + long[] data = new long[bitSet.size()]; + for (int i = 0; i < bitSet.size(); i++) { + data[i] = bitSet.get(i); + } + this.bitSet = new BitSet(data); + } + + /** + * A constructor to support rebuilding the BloomFilter from a serialized representation. + * @param bits the serialized bits + * @param numFuncs the number of functions used + */ + public Bloom1Filter(long[] bits, int numBits, int numFuncs) { + this(bits, numFuncs); + this.m = numBits; + } + + /** + * A constructor to support rebuilding the BloomFilter from a serialized representation. + * @param bits the serialized bits + * @param numFuncs the number of functions used + */ + public Bloom1Filter(long[] bits, int numFuncs) { + bitSet = new BitSet(bits); + this.m = (int) bitSet.bitSize(); + this.k = numFuncs; + } + + static int optimalNumOfHashFunctions(long n, long m) { + return Math.max(1, (int) Math.round((double) m / n * Math.log(2))); + } + + static int optimalNumOfBits(long n, double p) { + return (int) (-n * Math.log(p) / (Math.log(2) * Math.log(2))); + } + + public void add(byte[] val) { + addBytes(val); + } + + public void addBytes(byte[] val, int offset, int length) { + // We use the trick mentioned in "Less Hashing, Same Performance: Building a Better Bloom Filter" + // by Kirsch et.al. From abstract 'only two hash functions are necessary to effectively + // implement a Bloom filter without any loss in the asymptotic false positive probability' + + // Lets split up 64-bit hashcode into two 32-bit hash codes and employ the technique mentioned + // in the above paper + long hash64 = val == null ? Murmur3.NULL_HASHCODE : + Murmur3.hash64(val, offset, length); + addHash(hash64); + } + + public void addBytes(byte[] val) { + addBytes(val, 0, val.length); + } + + private void addHash(long hash64) { + // We use the trick mentioned in "Less Hashing, Same Performance: Building a Better Bloom Filter" + // by Kirsch et.al. From abstract 'only two hash functions are necessary to effectively + // implement a Bloom filter without any loss in the asymptotic false positive probability' + + // Lets split up 64-bit hashcode into two 32-bit hashcodes and employ the technique mentioned + // in the above paper + int hash1 = (int) hash64; + int hash2 = (int) (hash64 >>> 32); + + int firstHash = hash1 + hash2; + // hashcode should be positive, flip all the bits if it's negative + if (firstHash < 0) { + firstHash = ~firstHash; + } + + int wordIdx = firstHash % bitSet.data.length; + long word = bitSet.data[wordIdx]; + long mask = (1L << Long.SIZE - 1); + for (int i = 2; i <= k; i++) { + int combinedHash = hash1 + (i * hash2); + // hashcode should be positive, flip all the bits if it's negative + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + int pos = combinedHash & (Long.SIZE - 1); + mask |= (1L << pos); + } + bitSet.getData()[wordIdx] = word | mask; + } + + public void addString(String val) { + addBytes(val.getBytes()); + } + + public void addByte(byte val) { + addBytes(new byte[]{val}); + } + + public void addInt(int val) { + // puts int in little endian order + addBytes(intToByteArrayLE(val)); + } + + + public void addLong(long val) { + // puts long in little endian order + addBytes(longToByteArrayLE(val)); + } + + public void addFloat(float val) { + addInt(Float.floatToIntBits(val)); + } + + public void addDouble(double val) { + addLong(Double.doubleToLongBits(val)); + } + + public boolean test(byte[] val) { + return testBytes(val); + } + + public boolean testBytes(byte[] val) { + return testBytes(val, 0, val.length); + } + + public boolean testBytes(byte[] val, int offset, int length) { + long hash64 = val == null ? Murmur3.NULL_HASHCODE : + Murmur3.hash64(val, offset, length); + return testHash(hash64); + } + + private boolean testHash(long hash64) { + int hash1 = (int) hash64; + int hash2 = (int) (hash64 >>> 32); + + int firstHash = hash1 + hash2; + // hashcode should be positive, flip all the bits if it's negative + if (firstHash < 0) { + firstHash = ~firstHash; + } + int wordIdx = firstHash % bitSet.data.length; + long word = bitSet.data[wordIdx]; + long mask = (1L << Long.SIZE - 1); + for (int i = 2; i <= k; i++) { + int combinedHash = hash1 + (i * hash2); + // hashcode should be positive, flip all the bits if it's negative + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + int pos = combinedHash & (Long.SIZE - 1); + mask |= (1L << pos); + } + + return (word & mask) == mask; + } + + public boolean testString(String val) { + return testBytes(val.getBytes()); + } + + public boolean testByte(byte val) { + return testBytes(new byte[]{val}); + } + + public boolean testInt(int val) { + return testBytes(intToByteArrayLE(val)); + } + + public boolean testLong(long val) { + return testBytes(longToByteArrayLE(val)); + } + + public boolean testFloat(float val) { + return testInt(Float.floatToIntBits(val)); + } + + public boolean testDouble(double val) { + return testLong(Double.doubleToLongBits(val)); + } + + private byte[] intToByteArrayLE(int val) { + BYTE_ARRAY_4[0] = (byte) (val >> 0); + BYTE_ARRAY_4[1] = (byte) (val >> 8); + BYTE_ARRAY_4[2] = (byte) (val >> 16); + BYTE_ARRAY_4[3] = (byte) (val >> 24); + return BYTE_ARRAY_4; + } + + private byte[] longToByteArrayLE(long val) { + BYTE_ARRAY_8[0] = (byte) (val >> 0); + BYTE_ARRAY_8[1] = (byte) (val >> 8); + BYTE_ARRAY_8[2] = (byte) (val >> 16); + BYTE_ARRAY_8[3] = (byte) (val >> 24); + BYTE_ARRAY_8[4] = (byte) (val >> 32); + BYTE_ARRAY_8[5] = (byte) (val >> 40); + BYTE_ARRAY_8[6] = (byte) (val >> 48); + BYTE_ARRAY_8[7] = (byte) (val >> 56); + return BYTE_ARRAY_8; + } + + public long sizeInBytes() { + return getBitSize() / 8; + } + + public int getBitSize() { + return bitSet.getData().length * Long.SIZE; + } + + public int getNumHashFunctions() { + return k; + } + + public long[] getBitSet() { + return bitSet.getData(); + } + + @Override + public String toString() { + return "m: " + m + " k: " + k; + } + + /** + * Merge the specified bloom filter with current bloom filter. + * + * @param that - bloom filter to merge + */ + public void merge(Bloom1Filter that) { + if (this != that && this.m == that.m && this.k == that.k) { + this.bitSet.putAll(that.bitSet); + } else { + throw new IllegalArgumentException("BloomFilters are not compatible for merging." + + " this - " + this.toString() + " that - " + that.toString()); + } + } + + public void reset() { + this.bitSet.clear(); + } + + /** + * Serialize a bloom filter + * @param out output stream to write to + * @param bloomFilter BloomFilter that needs to be seralized + */ + public static void serialize(OutputStream out, Bloom1Filter bloomFilter) throws IOException { + /** + * Serialized BloomFilter format: + * 1 byte for the number of hash functions. + * 1 big endian int(That is how OutputStream works) for the number of longs in the bitset + * big endina longs in the BloomFilter bitset + */ + DataOutputStream dataOutputStream = new DataOutputStream(out); + dataOutputStream.writeByte(bloomFilter.k); + dataOutputStream.writeInt((int) bloomFilter.m); + for (long value : bloomFilter.getBitSet()) { + dataOutputStream.writeLong(value); + } + } + + /** + * Deserialize a bloom filter + * Read a byte stream, which was written by {@linkplain #serialize(OutputStream, Bloom1Filter)} + * into a {@code BloomFilter} + * @param in input bytestream + * @return deserialized BloomFilter + */ + public static Bloom1Filter deserialize(InputStream in) throws IOException { + if (in == null) { + throw new IOException("Input stream is null"); + } + + try { + DataInputStream dataInputStream = new DataInputStream(in); + int numHashFunc = dataInputStream.readByte(); + int numBits = dataInputStream.readInt(); + int sz = (numBits/Long.SIZE); + long[] data = new long[sz]; + for (int i = 0; i < sz; i++) { + data[i] = dataInputStream.readLong(); + } + return new Bloom1Filter(data, numBits, numHashFunc); + } catch (RuntimeException e) { + IOException io = new IOException( "Unable to deserialize BloomFilter"); + io.initCause(e); + throw io; + } + } + + // Given a byte array consisting of a serialized BloomFilter, gives the offset (from 0) + // for the start of the serialized long values that make up the bitset. + // NumHashFunctions (1 byte) + NumBits (4 bytes) + public static final int START_OF_SERIALIZED_LONGS = 5; + + /** + * Merges BloomFilter bf2 into bf1. + * Assumes 2 BloomFilters with the same size/hash functions are serialized to byte arrays + * @param bf1Bytes + * @param bf1Start + * @param bf1Length + * @param bf2Bytes + * @param bf2Start + * @param bf2Length + */ + public static void mergeBloomFilterBytes( + byte[] bf1Bytes, int bf1Start, int bf1Length, + byte[] bf2Bytes, int bf2Start, int bf2Length) { + if (bf1Length != bf2Length) { + throw new IllegalArgumentException("bf1Length " + bf1Length + " does not match bf2Length " + bf2Length); + } + + // Validation on the bitset size/3 hash functions. + for (int idx = 0; idx < START_OF_SERIALIZED_LONGS; ++idx) { + if (bf1Bytes[bf1Start + idx] != bf2Bytes[bf2Start + idx]) { + throw new IllegalArgumentException("bf1 NumHashFunctions/NumBits does not match bf2"); + } + } + + // Just bitwise-OR the bits together - size/# functions should be the same, + // rest of the data is serialized long values for the bitset which are supposed to be bitwise-ORed. + for (int idx = START_OF_SERIALIZED_LONGS; idx < bf1Length; ++idx) { + bf1Bytes[bf1Start + idx] |= bf2Bytes[bf2Start + idx]; + } + } + + /** + * Bare metal bit set implementation. For performance reasons, this implementation does not check + * for index bounds nor expand the bit set size if the specified index is greater than the size. + */ + public static class BitSet { + private final long[] data; + + public BitSet(long bits) { + this(new long[(int) Math.ceil((double) bits / (double) Long.SIZE)]); + } + + /** + * Deserialize long array as bit set. + * + * @param data - bit array + */ + public BitSet(long[] data) { + assert data.length > 0 : "data length is zero!"; + this.data = data; + } + + /** + * Sets the bit at specified index. + * + * @param index - position + */ + public void set(int index) { + data[index >>> 6] |= (1L << index); + } + + /** + * Returns true if the bit is set in the specified index. + * + * @param index - position + * @return - value at the bit position + */ + public boolean get(int index) { + return (data[index >>> 6] & (1L << index)) != 0; + } + + /** + * Number of bits + */ + public long bitSize() { + return (long) data.length * Long.SIZE; + } + + public long[] getData() { + return data; + } + + /** + * Combines the two BitArrays using bitwise OR. + */ + public void putAll(Bloom1Filter.BitSet array) { + assert data.length == array.data.length : + "BitArrays must be of equal length (" + data.length + "!= " + array.data.length + ")"; + for (int i = 0; i < data.length; i++) { + data[i] |= array.data[i]; + } + } + + /** + * Clear the bit set. + */ + public void clear() { + Arrays.fill(data, 0); + } + } +} diff --git a/storage-api/src/test/org/apache/hive/common/util/TestBloom1Filter.java b/storage-api/src/test/org/apache/hive/common/util/TestBloom1Filter.java new file mode 100644 index 0000000..51ab5ad --- /dev/null +++ b/storage-api/src/test/org/apache/hive/common/util/TestBloom1Filter.java @@ -0,0 +1,585 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hive.common.util; + +import static org.junit.Assert.assertEquals; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.util.ArrayList; +import java.util.Random; + +import org.junit.Assert; +import org.junit.Test; + +/** + * + */ +public class TestBloom1Filter { + private static final int COUNT = 100; + Random rand = new Random(123); + + @Test(expected = IllegalArgumentException.class) + public void testBloomIllegalArg1() { + Bloom1Filter bf = new Bloom1Filter(0, 0); + } + + @Test(expected = IllegalArgumentException.class) + public void testBloomIllegalArg2() { + Bloom1Filter bf = new Bloom1Filter(0, 0.1); + } + + @Test(expected = IllegalArgumentException.class) + public void testBloomIllegalArg3() { + Bloom1Filter bf = new Bloom1Filter(1, 0.0); + } + + @Test(expected = IllegalArgumentException.class) + public void testBloomIllegalArg4() { + Bloom1Filter bf = new Bloom1Filter(1, 1.0); + } + + @Test(expected = IllegalArgumentException.class) + public void testBloomIllegalArg5() { + Bloom1Filter bf = new Bloom1Filter(-1, -1); + } + + + @Test + public void testBloomNumBits() { + assertEquals(0, Bloom1Filter.optimalNumOfBits(0, 0)); + assertEquals(0, Bloom1Filter.optimalNumOfBits(0, 1)); + assertEquals(0, Bloom1Filter.optimalNumOfBits(1, 1)); + assertEquals(7, Bloom1Filter.optimalNumOfBits(1, 0.03)); + assertEquals(72, Bloom1Filter.optimalNumOfBits(10, 0.03)); + assertEquals(729, Bloom1Filter.optimalNumOfBits(100, 0.03)); + assertEquals(7298, Bloom1Filter.optimalNumOfBits(1000, 0.03)); + assertEquals(72984, Bloom1Filter.optimalNumOfBits(10000, 0.03)); + assertEquals(729844, Bloom1Filter.optimalNumOfBits(100000, 0.03)); + assertEquals(7298440, Bloom1Filter.optimalNumOfBits(1000000, 0.03)); + assertEquals(6235224, Bloom1Filter.optimalNumOfBits(1000000, 0.05)); + assertEquals(1870567268, Bloom1Filter.optimalNumOfBits(300000000, 0.05)); + assertEquals(1437758756, Bloom1Filter.optimalNumOfBits(300000000, 0.1)); + assertEquals(432808512, Bloom1Filter.optimalNumOfBits(300000000, 0.5)); + assertEquals(1393332198, Bloom1Filter.optimalNumOfBits(3000000000L, 0.8)); + assertEquals(657882327, Bloom1Filter.optimalNumOfBits(3000000000L, 0.9)); + assertEquals(0, Bloom1Filter.optimalNumOfBits(3000000000L, 1)); + } + + @Test + public void testBloomNumHashFunctions() { + assertEquals(1, Bloom1Filter.optimalNumOfHashFunctions(-1, -1)); + assertEquals(1, Bloom1Filter.optimalNumOfHashFunctions(0, 0)); + assertEquals(1, Bloom1Filter.optimalNumOfHashFunctions(10, 0)); + assertEquals(1, Bloom1Filter.optimalNumOfHashFunctions(10, 10)); + assertEquals(7, Bloom1Filter.optimalNumOfHashFunctions(10, 100)); + assertEquals(1, Bloom1Filter.optimalNumOfHashFunctions(100, 100)); + assertEquals(1, Bloom1Filter.optimalNumOfHashFunctions(1000, 100)); + assertEquals(1, Bloom1Filter.optimalNumOfHashFunctions(10000, 100)); + assertEquals(1, Bloom1Filter.optimalNumOfHashFunctions(100000, 100)); + assertEquals(1, Bloom1Filter.optimalNumOfHashFunctions(1000000, 100)); + } + + @Test + public void testBloom1FilterBytes() { + Bloom1Filter bf = new Bloom1Filter(10000); + byte[] val = new byte[]{1, 2, 3}; + byte[] val1 = new byte[]{1, 2, 3, 4}; + byte[] val2 = new byte[]{1, 2, 3, 4, 5}; + byte[] val3 = new byte[]{1, 2, 3, 4, 5, 6}; + + assertEquals(false, bf.test(val)); + assertEquals(false, bf.test(val1)); + assertEquals(false, bf.test(val2)); + assertEquals(false, bf.test(val3)); + bf.add(val); + assertEquals(true, bf.test(val)); + assertEquals(false, bf.test(val1)); + assertEquals(false, bf.test(val2)); + assertEquals(false, bf.test(val3)); + bf.add(val1); + assertEquals(true, bf.test(val)); + assertEquals(true, bf.test(val1)); + assertEquals(false, bf.test(val2)); + assertEquals(false, bf.test(val3)); + bf.add(val2); + assertEquals(true, bf.test(val)); + assertEquals(true, bf.test(val1)); + assertEquals(true, bf.test(val2)); + assertEquals(false, bf.test(val3)); + bf.add(val3); + assertEquals(true, bf.test(val)); + assertEquals(true, bf.test(val1)); + assertEquals(true, bf.test(val2)); + assertEquals(true, bf.test(val3)); + + byte[] randVal = new byte[COUNT]; + for (int i = 0; i < COUNT; i++) { + rand.nextBytes(randVal); + bf.add(randVal); + } + // last value should be present + assertEquals(true, bf.test(randVal)); + // most likely this value should not exist + randVal[0] = 0; + randVal[1] = 0; + randVal[2] = 0; + randVal[3] = 0; + randVal[4] = 0; + assertEquals(false, bf.test(randVal)); + + assertEquals(7800, bf.sizeInBytes()); + } + + @Test + public void testBloom1FilterByte() { + Bloom1Filter bf = new Bloom1Filter(10000); + byte val = Byte.MIN_VALUE; + byte val1 = 1; + byte val2 = 2; + byte val3 = Byte.MAX_VALUE; + + assertEquals(false, bf.testLong(val)); + assertEquals(false, bf.testLong(val1)); + assertEquals(false, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val); + assertEquals(true, bf.testLong(val)); + assertEquals(false, bf.testLong(val1)); + assertEquals(false, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val1); + assertEquals(true, bf.testLong(val)); + assertEquals(true, bf.testLong(val1)); + assertEquals(false, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val2); + assertEquals(true, bf.testLong(val)); + assertEquals(true, bf.testLong(val1)); + assertEquals(true, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val3); + assertEquals(true, bf.testLong(val)); + assertEquals(true, bf.testLong(val1)); + assertEquals(true, bf.testLong(val2)); + assertEquals(true, bf.testLong(val3)); + + byte randVal = 0; + for (int i = 0; i < COUNT; i++) { + randVal = (byte) rand.nextInt(Byte.MAX_VALUE); + bf.addLong(randVal); + } + // last value should be present + assertEquals(true, bf.testLong(randVal)); + // most likely this value should not exist + assertEquals(false, bf.testLong((byte) -120)); + + assertEquals(7800, bf.sizeInBytes()); + } + + @Test + public void testBloom1FilterInt() { + Bloom1Filter bf = new Bloom1Filter(10000); + int val = Integer.MIN_VALUE; + int val1 = 1; + int val2 = 2; + int val3 = Integer.MAX_VALUE; + + assertEquals(false, bf.testLong(val)); + assertEquals(false, bf.testLong(val1)); + assertEquals(false, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val); + assertEquals(true, bf.testLong(val)); + assertEquals(false, bf.testLong(val1)); + assertEquals(false, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val1); + assertEquals(true, bf.testLong(val)); + assertEquals(true, bf.testLong(val1)); + assertEquals(false, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val2); + assertEquals(true, bf.testLong(val)); + assertEquals(true, bf.testLong(val1)); + assertEquals(true, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val3); + assertEquals(true, bf.testLong(val)); + assertEquals(true, bf.testLong(val1)); + assertEquals(true, bf.testLong(val2)); + assertEquals(true, bf.testLong(val3)); + + int randVal = 0; + for (int i = 0; i < COUNT; i++) { + randVal = rand.nextInt(); + bf.addLong(randVal); + } + // last value should be present + assertEquals(true, bf.testLong(randVal)); + // most likely this value should not exist + assertEquals(false, bf.testLong(-120)); + + assertEquals(7800, bf.sizeInBytes()); + } + + @Test + public void testBloom1FilterLong() { + Bloom1Filter bf = new Bloom1Filter(10000); + long val = Long.MIN_VALUE; + long val1 = 1; + long val2 = 2; + long val3 = Long.MAX_VALUE; + + assertEquals(false, bf.testLong(val)); + assertEquals(false, bf.testLong(val1)); + assertEquals(false, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val); + assertEquals(true, bf.testLong(val)); + assertEquals(false, bf.testLong(val1)); + assertEquals(false, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val1); + assertEquals(true, bf.testLong(val)); + assertEquals(true, bf.testLong(val1)); + assertEquals(false, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val2); + assertEquals(true, bf.testLong(val)); + assertEquals(true, bf.testLong(val1)); + assertEquals(true, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val3); + assertEquals(true, bf.testLong(val)); + assertEquals(true, bf.testLong(val1)); + assertEquals(true, bf.testLong(val2)); + assertEquals(true, bf.testLong(val3)); + + long randVal = 0; + for (int i = 0; i < COUNT; i++) { + randVal = rand.nextLong(); + bf.addLong(randVal); + } + // last value should be present + assertEquals(true, bf.testLong(randVal)); + // most likely this value should not exist + assertEquals(false, bf.testLong(-120)); + + assertEquals(7800, bf.sizeInBytes()); + } + + @Test + public void testBloom1FilterFloat() { + Bloom1Filter bf = new Bloom1Filter(10000); + float val = Float.MIN_VALUE; + float val1 = 1.1f; + float val2 = 2.2f; + float val3 = Float.MAX_VALUE; + + assertEquals(false, bf.testDouble(val)); + assertEquals(false, bf.testDouble(val1)); + assertEquals(false, bf.testDouble(val2)); + assertEquals(false, bf.testDouble(val3)); + bf.addDouble(val); + assertEquals(true, bf.testDouble(val)); + assertEquals(false, bf.testDouble(val1)); + assertEquals(false, bf.testDouble(val2)); + assertEquals(false, bf.testDouble(val3)); + bf.addDouble(val1); + assertEquals(true, bf.testDouble(val)); + assertEquals(true, bf.testDouble(val1)); + assertEquals(false, bf.testDouble(val2)); + assertEquals(false, bf.testDouble(val3)); + bf.addDouble(val2); + assertEquals(true, bf.testDouble(val)); + assertEquals(true, bf.testDouble(val1)); + assertEquals(true, bf.testDouble(val2)); + assertEquals(false, bf.testDouble(val3)); + bf.addDouble(val3); + assertEquals(true, bf.testDouble(val)); + assertEquals(true, bf.testDouble(val1)); + assertEquals(true, bf.testDouble(val2)); + assertEquals(true, bf.testDouble(val3)); + + float randVal = 0; + for (int i = 0; i < COUNT; i++) { + randVal = rand.nextFloat(); + bf.addDouble(randVal); + } + // last value should be present + assertEquals(true, bf.testDouble(randVal)); + // most likely this value should not exist + assertEquals(false, bf.testDouble(-120.2f)); + + assertEquals(7800, bf.sizeInBytes()); + } + + @Test + public void testBloom1FilterDouble() { + Bloom1Filter bf = new Bloom1Filter(10000); + double val = Double.MIN_VALUE; + double val1 = 1.1d; + double val2 = 2.2d; + double val3 = Double.MAX_VALUE; + + assertEquals(false, bf.testDouble(val)); + assertEquals(false, bf.testDouble(val1)); + assertEquals(false, bf.testDouble(val2)); + assertEquals(false, bf.testDouble(val3)); + bf.addDouble(val); + assertEquals(true, bf.testDouble(val)); + assertEquals(false, bf.testDouble(val1)); + assertEquals(false, bf.testDouble(val2)); + assertEquals(false, bf.testDouble(val3)); + bf.addDouble(val1); + assertEquals(true, bf.testDouble(val)); + assertEquals(true, bf.testDouble(val1)); + assertEquals(false, bf.testDouble(val2)); + assertEquals(false, bf.testDouble(val3)); + bf.addDouble(val2); + assertEquals(true, bf.testDouble(val)); + assertEquals(true, bf.testDouble(val1)); + assertEquals(true, bf.testDouble(val2)); + assertEquals(false, bf.testDouble(val3)); + bf.addDouble(val3); + assertEquals(true, bf.testDouble(val)); + assertEquals(true, bf.testDouble(val1)); + assertEquals(true, bf.testDouble(val2)); + assertEquals(true, bf.testDouble(val3)); + + double randVal = 0; + for (int i = 0; i < COUNT; i++) { + randVal = rand.nextDouble(); + bf.addDouble(randVal); + } + // last value should be present + assertEquals(true, bf.testDouble(randVal)); + // most likely this value should not exist + assertEquals(false, bf.testDouble(-120.2d)); + + assertEquals(7800, bf.sizeInBytes()); + } + + @Test + public void testBloom1FilterString() { + Bloom1Filter bf = new Bloom1Filter(100000); + String val = "bloo"; + String val1 = "bloom fil"; + String val2 = "bloom filter"; + String val3 = "cuckoo filter"; + + assertEquals(false, bf.testString(val)); + assertEquals(false, bf.testString(val1)); + assertEquals(false, bf.testString(val2)); + assertEquals(false, bf.testString(val3)); + bf.addString(val); + assertEquals(true, bf.testString(val)); + assertEquals(false, bf.testString(val1)); + assertEquals(false, bf.testString(val2)); + assertEquals(false, bf.testString(val3)); + bf.addString(val1); + assertEquals(true, bf.testString(val)); + assertEquals(true, bf.testString(val1)); + assertEquals(false, bf.testString(val2)); + assertEquals(false, bf.testString(val3)); + bf.addString(val2); + assertEquals(true, bf.testString(val)); + assertEquals(true, bf.testString(val1)); + assertEquals(true, bf.testString(val2)); + assertEquals(false, bf.testString(val3)); + bf.addString(val3); + assertEquals(true, bf.testString(val)); + assertEquals(true, bf.testString(val1)); + assertEquals(true, bf.testString(val2)); + assertEquals(true, bf.testString(val3)); + + long randVal = 0; + for (int i = 0; i < COUNT; i++) { + randVal = rand.nextLong(); + bf.addString(Long.toString(randVal)); + } + // last value should be present + assertEquals(true, bf.testString(Long.toString(randVal))); + // most likely this value should not exist + assertEquals(false, bf.testString(Long.toString(-120))); + + assertEquals(77944, bf.sizeInBytes()); + } + + @Test + public void testMerge() { + Bloom1Filter bf = new Bloom1Filter(10000); + String val = "bloo"; + String val1 = "bloom fil"; + String val2 = "bloom filter"; + String val3 = "cuckoo filter"; + bf.addString(val); + bf.addString(val1); + bf.addString(val2); + bf.addString(val3); + + Bloom1Filter bf2 = new Bloom1Filter(10000); + String v = "2_bloo"; + String v1 = "2_bloom fil"; + String v2 = "2_bloom filter"; + String v3 = "2_cuckoo filter"; + bf2.addString(v); + bf2.addString(v1); + bf2.addString(v2); + bf2.addString(v3); + + assertEquals(true, bf.testString(val)); + assertEquals(true, bf.testString(val1)); + assertEquals(true, bf.testString(val2)); + assertEquals(true, bf.testString(val3)); + assertEquals(false, bf.testString(v)); + assertEquals(false, bf.testString(v1)); + assertEquals(false, bf.testString(v2)); + assertEquals(false, bf.testString(v3)); + + bf.merge(bf2); + + assertEquals(true, bf.testString(val)); + assertEquals(true, bf.testString(val1)); + assertEquals(true, bf.testString(val2)); + assertEquals(true, bf.testString(val3)); + assertEquals(true, bf.testString(v)); + assertEquals(true, bf.testString(v1)); + assertEquals(true, bf.testString(v2)); + assertEquals(true, bf.testString(v3)); + } + + @Test + public void testSerialize() throws Exception { + Bloom1Filter bf1 = new Bloom1Filter(10000); + String[] inputs = { + "bloo", + "bloom fil", + "bloom filter", + "cuckoo filter", + }; + + for (String val : inputs) { + bf1.addString(val); + } + + // Serialize/deserialize + ByteArrayOutputStream bytesOut = new ByteArrayOutputStream(); + Bloom1Filter.serialize(bytesOut, bf1); + ByteArrayInputStream bytesIn = new ByteArrayInputStream(bytesOut.toByteArray()); + Bloom1Filter bf2 = Bloom1Filter.deserialize(bytesIn); + + for (String val : inputs) { + assertEquals("Testing bf1 with " + val, true, bf1.testString(val)); + assertEquals("Testing bf2 with " + val, true, bf2.testString(val)); + } + } + + @Test + public void testMergeBloom1FilterBytes() throws Exception { + Bloom1Filter bf1 = new Bloom1Filter(10000); + Bloom1Filter bf2 = new Bloom1Filter(10000); + + String[] inputs1 = { + "bloo", + "bloom fil", + "bloom filter", + "cuckoo filter", + }; + + String[] inputs2 = { + "2_bloo", + "2_bloom fil", + "2_bloom filter", + "2_cuckoo filter", + }; + + for (String val : inputs1) { + bf1.addString(val); + } + for (String val : inputs2) { + bf2.addString(val); + } + + ByteArrayOutputStream bytesOut = new ByteArrayOutputStream(); + Bloom1Filter.serialize(bytesOut, bf1); + byte[] bf1Bytes = bytesOut.toByteArray(); + bytesOut.reset(); + Bloom1Filter.serialize(bytesOut, bf1); + byte[] bf2Bytes = bytesOut.toByteArray(); + + // Merge bytes + Bloom1Filter.mergeBloomFilterBytes( + bf1Bytes, 0, bf1Bytes.length, + bf2Bytes, 0, bf2Bytes.length); + + // Deserialize and test + ByteArrayInputStream bytesIn = new ByteArrayInputStream(bf1Bytes, 0, bf1Bytes.length); + Bloom1Filter bfMerged = Bloom1Filter.deserialize(bytesIn); + // All values should pass test + for (String val : inputs1) { + bfMerged.addString(val); + } + for (String val : inputs2) { + bfMerged.addString(val); + } + } + + @Test + public void testMergeBloom1FilterBytesFailureCases() throws Exception { + Bloom1Filter bf1 = new Bloom1Filter(1000); + Bloom1Filter bf2 = new Bloom1Filter(200); + long[] bits = new long[bf1.getBitSet().length]; + Bloom1Filter bf3 = new Bloom1Filter(new long[bf1.getBitSet().length], bf1.getNumHashFunctions() + 1); + + // Serialize to bytes + ByteArrayOutputStream bytesOut = new ByteArrayOutputStream(); + Bloom1Filter.serialize(bytesOut, bf1); + byte[] bf1Bytes = bytesOut.toByteArray(); + + bytesOut.reset(); + Bloom1Filter.serialize(bytesOut, bf2); + byte[] bf2Bytes = bytesOut.toByteArray(); + + bytesOut.reset(); + Bloom1Filter.serialize(bytesOut, bf3); + byte[] bf3Bytes = bytesOut.toByteArray(); + + try { + // this should fail + Bloom1Filter.mergeBloomFilterBytes( + bf1Bytes, 0, bf1Bytes.length, + bf2Bytes, 0, bf2Bytes.length); + Assert.fail("Expected exception not encountered"); + } catch (IllegalArgumentException err) { + // expected + } + + try { + // this should fail + Bloom1Filter.mergeBloomFilterBytes( + bf1Bytes, 0, bf1Bytes.length, + bf3Bytes, 0, bf3Bytes.length); + Assert.fail("Expected exception not encountered"); + } catch (IllegalArgumentException err) { + // expected + } + } +}