diff --git a/metastore/src/java/org/apache/hadoop/hive/metastore/hbase/AggrStatsInvalidatorFilter.java b/metastore/src/java/org/apache/hadoop/hive/metastore/hbase/AggrStatsInvalidatorFilter.java index 4ca4229..2db5c38 100644 --- a/metastore/src/java/org/apache/hadoop/hive/metastore/hbase/AggrStatsInvalidatorFilter.java +++ b/metastore/src/java/org/apache/hadoop/hive/metastore/hbase/AggrStatsInvalidatorFilter.java @@ -18,6 +18,7 @@ */ package org.apache.hadoop.hive.metastore.hbase; +import com.google.common.primitives.Longs; import com.google.protobuf.InvalidProtocolBufferException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -101,10 +102,8 @@ public ReturnCode filterKeyValue(Cell cell) throws IOException { entry.getTableName().equals(fromCol.getTableName())) { if (bloom == null) { // Now, reconstitute the bloom filter and probe it with each of our partition names - bloom = new BloomFilter( - fromCol.getBloomFilter().getBitsList(), - fromCol.getBloomFilter().getNumBits(), - fromCol.getBloomFilter().getNumFuncs()); + List bitsList = fromCol.getBloomFilter().getBitsList(); + bloom = new BloomFilter(Longs.toArray(bitsList), fromCol.getBloomFilter().getNumFuncs()); } if (bloom.test(entry.getPartName().toByteArray())) { // This is most likely a match, so mark it and quit looking. diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorInBloomFilterColDynamicValue.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorInBloomFilterColDynamicValue.java index 188a87e..25440d6 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorInBloomFilterColDynamicValue.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorInBloomFilterColDynamicValue.java @@ -19,8 +19,10 @@ package org.apache.hadoop.hive.ql.exec.vector.expressions; import java.io.ByteArrayInputStream; +import java.io.InputStream; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.common.io.NonSyncByteArrayInputStream; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; @@ -33,20 +35,16 @@ import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; import org.apache.hadoop.hive.ql.plan.DynamicValue; import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector; -import org.apache.hive.common.util.BloomFilter; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import org.apache.hadoop.io.IOUtils; +import org.apache.hive.common.util.BloomKFilter; public class VectorInBloomFilterColDynamicValue extends VectorExpression { private static final long serialVersionUID = 1L; - private static final Logger LOG = LoggerFactory.getLogger(VectorInBloomFilterColDynamicValue.class); - protected int colNum; protected DynamicValue bloomFilterDynamicValue; protected transient boolean initialized = false; - protected transient BloomFilter bloomFilter; + protected transient BloomKFilter bloomFilter; protected transient BloomFilterCheck bfCheck; public VectorInBloomFilterColDynamicValue(int colNum, DynamicValue bloomFilterDynamicValue) { @@ -90,18 +88,22 @@ public void init(Configuration conf) { } private void initValue() { + InputStream in = null; try { Object val = bloomFilterDynamicValue.getValue(); if (val != null) { BinaryObjectInspector boi = (BinaryObjectInspector) bloomFilterDynamicValue.getObjectInspector(); byte[] bytes = boi.getPrimitiveJavaObject(val); - 
bloomFilter = BloomFilter.deserialize(new ByteArrayInputStream(bytes)); + in = new NonSyncByteArrayInputStream(bytes); + bloomFilter = BloomKFilter.deserialize(in); } else { bloomFilter = null; } initialized = true; } catch (Exception err) { throw new RuntimeException(err); + } finally { + IOUtils.closeStream(in); } } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFBloomFilter.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFBloomFilter.java index 4b3eca09..0e308f9 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFBloomFilter.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFBloomFilter.java @@ -20,7 +20,6 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.util.Arrays; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; @@ -33,25 +32,19 @@ import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; -import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression.AggregationBuffer; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.plan.AggregationDesc; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFBloomFilter; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFBloomFilter.GenericUDAFBloomFilterEvaluator; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.io.BytesWritable; -import org.apache.hadoop.io.Text; -import org.apache.hive.common.util.BloomFilter; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import org.apache.hadoop.io.IOUtils; +import org.apache.hive.common.util.BloomKFilter; public class VectorUDAFBloomFilter extends VectorAggregateExpression { - private static final Logger LOG = LoggerFactory.getLogger(VectorUDAFBloomFilter.class); - private static final long serialVersionUID = 1L; private long expectedEntries = -1; @@ -66,10 +59,10 @@ private static final class Aggregation implements AggregationBuffer { private static final long serialVersionUID = 1L; - BloomFilter bf; + BloomKFilter bf; public Aggregation(long expectedEntries) { - bf = new BloomFilter(expectedEntries); + bf = new BloomKFilter(expectedEntries); } @Override @@ -363,12 +356,14 @@ public Object evaluateOutput(AggregationBuffer agg) throws HiveException { try { Aggregation bfAgg = (Aggregation) agg; byteStream.reset(); - BloomFilter.serialize(byteStream, bfAgg.bf); + BloomKFilter.serialize(byteStream, bfAgg.bf); byte[] bytes = byteStream.toByteArray(); bw.set(bytes, 0, bytes.length); return bw; } catch (IOException err) { throw new HiveException("Error encountered while serializing bloomfilter", err); + } finally { + IOUtils.closeStream(byteStream); } } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFBloomFilterMerge.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFBloomFilterMerge.java 
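For context on the two vectorized operators above: both sides agree on a byte-level contract. The aggregation serializes a BloomKFilter into a BytesWritable, and VectorInBloomFilterColDynamicValue deserializes it once before probing rows. Below is a minimal sketch of that round trip, using only APIs introduced by this patch (the class name BloomKFilterRoundTrip and the sample values are illustrative; the patch itself wraps the bytes in NonSyncByteArrayInputStream to avoid ByteArrayInputStream's synchronized reads, while plain ByteArrayInputStream suffices for illustration):

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;

import org.apache.hive.common.util.BloomKFilter;

public class BloomKFilterRoundTrip {
  public static void main(String[] args) throws Exception {
    // Build side: collect keys into the filter, then serialize it for shipping.
    BloomKFilter built = new BloomKFilter(10_000);
    built.addLong(42L);

    ByteArrayOutputStream out = new ByteArrayOutputStream();
    BloomKFilter.serialize(out, built);
    byte[] payload = out.toByteArray();

    // Probe side: reconstitute the filter once, then test each incoming key.
    BloomKFilter probe = BloomKFilter.deserialize(new ByteArrayInputStream(payload));
    System.out.println(probe.testLong(42L)); // true
    System.out.println(probe.testLong(43L)); // false with high probability
  }
}
```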
index 67a7c50..1a6d2b7 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFBloomFilterMerge.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFBloomFilterMerge.java @@ -26,7 +26,6 @@ import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; -import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression.AggregationBuffer; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.plan.AggregationDesc; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; @@ -34,10 +33,10 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.io.BytesWritable; -import org.apache.hive.common.util.BloomFilter; +import org.apache.hadoop.io.IOUtils; +import org.apache.hive.common.util.BloomKFilter; public class VectorUDAFBloomFilterMerge extends VectorAggregateExpression { - private static final long serialVersionUID = 1L; private long expectedEntries = -1; @@ -53,13 +52,16 @@ byte[] bfBytes; public Aggregation(long expectedEntries) { + ByteArrayOutputStream bytesOut = null; try { - BloomFilter bf = new BloomFilter(expectedEntries); - ByteArrayOutputStream bytesOut = new ByteArrayOutputStream(); - BloomFilter.serialize(bytesOut, bf); + BloomKFilter bf = new BloomKFilter(expectedEntries); + bytesOut = new ByteArrayOutputStream(); + BloomKFilter.serialize(bytesOut, bf); bfBytes = bytesOut.toByteArray(); } catch (Exception err) { throw new IllegalArgumentException("Error creating aggregation buffer", err); + } finally { + IOUtils.closeStream(bytesOut); } } @@ -71,7 +73,7 @@ public int getVariableSize() { @Override public void reset() { // Do not change the initial bytes which contain NumHashFunctions/NumBits! - Arrays.fill(bfBytes, BloomFilter.START_OF_SERIALIZED_LONGS, bfBytes.length, (byte) 0); + Arrays.fill(bfBytes, BloomKFilter.START_OF_SERIALIZED_LONGS, bfBytes.length, (byte) 0); } } @@ -362,7 +364,7 @@ void processValue(Aggregation myagg, ColumnVector columnVector, int i) { // BloomFilter.mergeBloomFilterBytes() does a simple byte ORing // which should be faster than deserialize/merge. 
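 // Concretely: the serialized layout is 1 byte numHashFunctions + 4 byte long count
 // + big endian longs, so ORing the bytes past START_OF_SERIALIZED_LONGS is
 // equivalent to ORing the underlying long[] words (the same effect as BitSet.putAll).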
BytesColumnVector inputColumn = (BytesColumnVector) columnVector; - BloomFilter.mergeBloomFilterBytes(myagg.bfBytes, 0, myagg.bfBytes.length, + BloomKFilter.mergeBloomFilterBytes(myagg.bfBytes, 0, myagg.bfBytes.length, inputColumn.vector[i], inputColumn.start[i], inputColumn.length[i]); } } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFBloomFilter.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFBloomFilter.java index 2413ae6..3d85cc4 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFBloomFilter.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFBloomFilter.java @@ -18,6 +18,7 @@ package org.apache.hadoop.hive.ql.udf.generic; +import org.apache.hadoop.hive.common.io.NonSyncByteArrayInputStream; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.ql.exec.Operator; import org.apache.hadoop.hive.ql.exec.SelectOperator; @@ -28,20 +29,16 @@ import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc; import org.apache.hadoop.hive.ql.plan.ExprNodeDescUtils; import org.apache.hadoop.hive.ql.plan.Statistics; -import org.apache.hadoop.hive.ql.plan.Statistics.State; import org.apache.hadoop.hive.serde2.io.DateWritable; import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.*; -import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; -import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; import org.apache.hadoop.io.BytesWritable; +import org.apache.hadoop.io.IOUtils; import org.apache.hadoop.io.Text; -import org.apache.hive.common.util.BloomFilter; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import org.apache.hive.common.util.BloomKFilter; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -54,8 +51,6 @@ */ public class GenericUDAFBloomFilter implements GenericUDAFResolver2 { - private static final Logger LOG = LoggerFactory.getLogger(GenericUDAFBloomFilter.class); - @Override public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) throws SemanticException { return new GenericUDAFBloomFilterEvaluator(); @@ -106,13 +101,13 @@ public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveExc */ @AggregationType(estimable = true) static class BloomFilterBuf extends AbstractAggregationBuffer { - BloomFilter bloomFilter; + BloomKFilter bloomFilter; public BloomFilterBuf(long expectedEntries, long maxEntries) { if (expectedEntries > maxEntries) { - bloomFilter = new BloomFilter(1); + bloomFilter = new BloomKFilter(maxEntries); } else { - bloomFilter = new BloomFilter(expectedEntries); + bloomFilter = new BloomKFilter(expectedEntries); } } @@ -147,7 +142,7 @@ public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveExcep return; } - BloomFilter bf = ((BloomFilterBuf)agg).bloomFilter; + BloomKFilter bf = ((BloomFilterBuf)agg).bloomFilter; // Add the expression into the BloomFilter switch (inputOI.getPrimitiveCategory()) { @@ -228,13 +223,15 @@ public void merge(AggregationBuffer agg, Object partial) throws HiveException { } BytesWritable bytes = (BytesWritable) partial; - ByteArrayInputStream in = new ByteArrayInputStream(bytes.getBytes()); - // Deserialze the bloomfilter + ByteArrayInputStream in = new 
NonSyncByteArrayInputStream(bytes.getBytes()); + // Deserialize the bloom filter try { - BloomFilter bf = BloomFilter.deserialize(in); + BloomKFilter bf = BloomKFilter.deserialize(in); ((BloomFilterBuf)agg).bloomFilter.merge(bf); } catch (IOException e) { throw new HiveException(e); + } finally { + IOUtils.closeStream(in); } } @@ -242,9 +239,11 @@ public void merge(AggregationBuffer agg, Object partial) throws HiveException { public Object terminate(AggregationBuffer agg) throws HiveException { result.reset(); try { - BloomFilter.serialize(result, ((BloomFilterBuf)agg).bloomFilter); + BloomKFilter.serialize(result, ((BloomFilterBuf)agg).bloomFilter); } catch (IOException e) { throw new HiveException(e); + } finally { + IOUtils.closeStream(result); } return new BytesWritable(result.toByteArray()); } @@ -326,6 +325,7 @@ public void setFactor(float factor) { public float getFactor() { return factor; } + @Override public String getExprString() { return "expectedEntries=" + getExpectedEntries(); diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFInBloomFilter.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFInBloomFilter.java index 3e6e069..786db83 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFInBloomFilter.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFInBloomFilter.java @@ -18,6 +18,7 @@ package org.apache.hadoop.hive.ql.udf.generic; +import org.apache.hadoop.hive.common.io.NonSyncByteArrayInputStream; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; @@ -32,13 +33,13 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.*; import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; import org.apache.hadoop.io.BytesWritable; +import org.apache.hadoop.io.IOUtils; import org.apache.hadoop.io.Text; -import org.apache.hive.common.util.BloomFilter; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import org.apache.hive.common.util.BloomKFilter; import java.io.ByteArrayInputStream; import java.io.IOException; +import java.io.InputStream; import java.sql.Timestamp; /** @@ -46,11 +47,10 @@ */ @VectorizedExpressions({VectorInBloomFilterColDynamicValue.class}) public class GenericUDFInBloomFilter extends GenericUDF { - private static final Logger LOG = LoggerFactory.getLogger(GenericUDFInBloomFilter.class); private transient ObjectInspector valObjectInspector; private transient ObjectInspector bloomFilterObjectInspector; - private transient BloomFilter bloomFilter; + private transient BloomKFilter bloomFilter; private transient boolean initializedBloomFilter; private transient byte[] scratchBuffer = new byte[HiveDecimal.SCRATCH_BUFFER_LEN_TO_BYTES]; @@ -95,13 +95,17 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException { if (!initializedBloomFilter) { // Setup the bloom filter once + InputStream in = null; try { BytesWritable bw = (BytesWritable) arguments[1].get(); byte[] bytes = new byte[bw.getLength()]; System.arraycopy(bw.getBytes(), 0, bytes, 0, bw.getLength()); - bloomFilter = BloomFilter.deserialize(new ByteArrayInputStream(bytes)); + in = new NonSyncByteArrayInputStream(bytes); + bloomFilter = BloomKFilter.deserialize(in); } catch ( IOException e) { throw new HiveException(e); + } finally { + IOUtils.closeStream(in); } initializedBloomFilter = true; } diff --git a/storage-api/src/java/org/apache/hive/common/util/BloomFilter.java 
b/storage-api/src/java/org/apache/hive/common/util/BloomFilter.java index e9f419d..706b834 100644 --- a/storage-api/src/java/org/apache/hive/common/util/BloomFilter.java +++ b/storage-api/src/java/org/apache/hive/common/util/BloomFilter.java @@ -19,9 +19,7 @@ package org.apache.hive.common.util; import java.io.*; -import java.util.ArrayList; import java.util.Arrays; -import java.util.List; /** * BloomFilter is a probabilistic data structure for set membership check. BloomFilters are @@ -72,17 +70,15 @@ public BloomFilter(long expectedEntries, double fpp) { /** * A constructor to support rebuilding the BloomFilter from a serialized representation. - * @param bits - * @param numBits - * @param numFuncs + * @param bits - bits are used as such for bitset and are NOT copied, any changes to bits will affect bloom filter + * @param numFuncs - number of hash functions */ - public BloomFilter(List bits, int numBits, int numFuncs) { + public BloomFilter(long[] bits, int numFuncs) { super(); - long[] copied = new long[bits.size()]; - for (int i = 0; i < bits.size(); i++) copied[i] = bits.get(i); - bitSet = new BitSet(copied); - this.numBits = numBits; - numHashFunctions = numFuncs; + // input long[] is set as such without copying, so any modification to the source will affect bloom filter + this.bitSet = new BitSet(bits); + this.numBits = bits.length * Long.SIZE; + this.numHashFunctions = numFuncs; } static int optimalNumOfHashFunctions(long n, long m) { @@ -118,7 +114,7 @@ private void addHash(long hash64) { int hash2 = (int) (hash64 >>> 32); for (int i = 1; i <= numHashFunctions; i++) { - int combinedHash = hash1 + (i * hash2); + int combinedHash = hash1 + ((i + 1) * hash2); // hashcode should be positive, flip all the bits if it's negative if (combinedHash < 0) { combinedHash = ~combinedHash; @@ -162,7 +158,7 @@ private boolean testHash(long hash64) { int hash2 = (int) (hash64 >>> 32); for (int i = 1; i <= numHashFunctions; i++) { - int combinedHash = hash1 + (i * hash2); + int combinedHash = hash1 + ((i + 1) * hash2); // hashcode should be positive, flip all the bits if it's negative if (combinedHash < 0) { combinedHash = ~combinedHash; @@ -253,11 +249,11 @@ public static void serialize(OutputStream out, BloomFilter bloomFilter) throws I * Serialized BloomFilter format: * 1 byte for the number of hash functions. 
     * 1 big endian int(That is how OutputStream works) for the number of longs in the bitset
-    * big endina longs in the BloomFilter bitset
+    * big endian longs in the BloomFilter bitset
     */
    DataOutputStream dataOutputStream = new DataOutputStream(out);
    dataOutputStream.writeByte(bloomFilter.numHashFunctions);
-   dataOutputStream.writeInt(bloomFilter.numBits);
+   dataOutputStream.writeInt(bloomFilter.getBitSet().length);
    for (long value : bloomFilter.getBitSet()) {
      dataOutputStream.writeLong(value);
    }
@@ -278,13 +274,12 @@ public static BloomFilter deserialize(InputStream in) throws IOException {
    try {
      DataInputStream dataInputStream = new DataInputStream(in);
      int numHashFunc = dataInputStream.readByte();
-     int numBits = dataInputStream.readInt();
-     int sz = (numBits/Long.SIZE);
-     List data = new ArrayList();
-     for (int i = 0; i < sz; i++) {
-       data.add(dataInputStream.readLong());
+     int numLongs = dataInputStream.readInt();
+     long[] data = new long[numLongs];
+     for (int i = 0; i < numLongs; i++) {
+       data[i] = dataInputStream.readLong();
      }
-     return new BloomFilter(data, numBits, numHashFunc);
+     return new BloomFilter(data, numHashFunc);
    } catch (RuntimeException e) {
      IOException io = new IOException( "Unable to deserialize BloomFilter");
      io.initCause(e);
diff --git a/storage-api/src/java/org/apache/hive/common/util/BloomKFilter.java b/storage-api/src/java/org/apache/hive/common/util/BloomKFilter.java
new file mode 100644
index 0000000..45326ab
--- /dev/null
+++ b/storage-api/src/java/org/apache/hive/common/util/BloomKFilter.java
@@ -0,0 +1,472 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hive.common.util;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.Arrays;
+
+/**
+ * BloomKFilter is a variation of {@link BloomFilter}. Unlike BloomFilter, BloomKFilter spreads
+ * its 'k' hash bits within the same cache line for better L1 cache performance. It works as
+ * follows: a first hash code computed from the key locates the block offset (n longs in the
+ * bitset constitute a block), and 'k' subsequent hash codes spread hash bits within that block.
+ * The default block size is 8, chosen to match the cache line size (8 longs = 64 bytes = cache
+ * line size). Refer to {@link BloomKFilter#addBytes(byte[])} for more info.
+ *
+ * This implementation incurs far fewer L1 data cache misses than {@link BloomFilter}.
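+ *
+ * For example (illustrative numbers, not from this patch): with the default block size of
+ * 8 longs, a bitset of 1024 longs has 128 blocks. The first hash picks one of the 128 blocks,
+ * and each of the 'k' probe hashes sets one bit inside that 64-byte block: the 3 LSBs select
+ * the long within the block and the next 6 bits select the bit within that long.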
+ */
+public class BloomKFilter {
+  private byte[] BYTE_ARRAY_4 = new byte[4];
+  private byte[] BYTE_ARRAY_8 = new byte[8];
+  public static final float DEFAULT_FPP = 0.05f;
+  private static final int DEFAULT_BLOCK_SIZE = 8;
+  private static final int DEFAULT_BLOCK_SIZE_BITS = (int) (Math.log(DEFAULT_BLOCK_SIZE) / Math.log(2));
+  private static final int DEFAULT_BLOCK_OFFSET_MASK = DEFAULT_BLOCK_SIZE - 1;
+  private static final int DEFAULT_BIT_OFFSET_MASK = Long.SIZE - 1;
+  private final long[] masks = new long[DEFAULT_BLOCK_SIZE];
+  private BitSet bitSet;
+  private final int m;
+  private final int k;
+  // spread k-1 bits to adjacent longs, default is 8
+  // spreading hash bits within blockSize * longs will make bloom filter L1 cache friendly
+  // default block size is set to 8 as most cache line sizes are 64 bytes and also AVX512 friendly
+  private final int totalBlockCount;
+
+  static void checkArgument(boolean expression, String message) {
+    if (!expression) {
+      throw new IllegalArgumentException(message);
+    }
+  }
+
+  public BloomKFilter(long maxNumEntries) {
+    checkArgument(maxNumEntries > 0, "expectedEntries should be > 0");
+    long numBits = optimalNumOfBits(maxNumEntries, DEFAULT_FPP);
+    this.k = optimalNumOfHashFunctions(maxNumEntries, numBits);
+    int nLongs = (int) Math.ceil((double) numBits / (double) Long.SIZE);
+    // additional bits to pad long array to block size
+    int padLongs = DEFAULT_BLOCK_SIZE - nLongs % DEFAULT_BLOCK_SIZE;
+    this.m = (nLongs + padLongs) * Long.SIZE;
+    this.bitSet = new BitSet(m);
+    checkArgument((bitSet.data.length % DEFAULT_BLOCK_SIZE) == 0, "bitSet has to be block aligned");
+    this.totalBlockCount = bitSet.data.length / DEFAULT_BLOCK_SIZE;
+  }
+
+  /**
+   * A constructor to support rebuilding the BloomKFilter from a serialized representation.
+   * @param bits - the serialized long array; used as-is for the bitset without copying
+   * @param numFuncs - number of hash functions
+   */
+  public BloomKFilter(long[] bits, int numFuncs) {
+    super();
+    bitSet = new BitSet(bits);
+    this.m = bits.length * Long.SIZE;
+    this.k = numFuncs;
+    checkArgument((bitSet.data.length % DEFAULT_BLOCK_SIZE) == 0, "bitSet has to be block aligned");
+    this.totalBlockCount = bitSet.data.length / DEFAULT_BLOCK_SIZE;
+  }
+  static int optimalNumOfHashFunctions(long n, long m) {
+    return Math.max(1, (int) Math.round((double) m / n * Math.log(2)));
+  }
+
+  static long optimalNumOfBits(long n, double p) {
+    return (long) (-n * Math.log(p) / (Math.log(2) * Math.log(2)));
+  }
+
+  public void add(byte[] val) {
+    addBytes(val);
+  }
+
+  public void addBytes(byte[] val, int offset, int length) {
+    // We use the trick mentioned in "Less Hashing, Same Performance: Building a Better Bloom Filter"
+    // by Kirsch et al. From the abstract: 'only two hash functions are necessary to effectively
+    // implement a Bloom filter without any loss in the asymptotic false positive probability'
+
+    // Let's split up the 64-bit hashcode into two 32-bit hash codes and employ the technique mentioned
+    // in the above paper
+    long hash64 = val == null ?
Murmur3.NULL_HASHCODE : + Murmur3.hash64(val, offset, length); + addHash(hash64); + } + + public void addBytes(byte[] val) { + addBytes(val, 0, val.length); + } + + private void addHash(long hash64) { + final int hash1 = (int) hash64; + final int hash2 = (int) (hash64 >>> 32); + + int firstHash = hash1 + hash2; + // hashcode should be positive, flip all the bits if it's negative + if (firstHash < 0) { + firstHash = ~firstHash; + } + + // first hash is used to locate start of the block (blockBaseOffset) + // subsequent K hashes are used to generate K bits within a block of words + final int blockIdx = firstHash % totalBlockCount; + final int blockBaseOffset = blockIdx << DEFAULT_BLOCK_SIZE_BITS; + for (int i = 1; i <= k; i++) { + int combinedHash = hash1 + ((i + 1) * hash2); + // hashcode should be positive, flip all the bits if it's negative + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + // LSB 3 bits is used to locate offset within the block + final int absOffset = blockBaseOffset + (combinedHash & DEFAULT_BLOCK_OFFSET_MASK); + // Next 6 bits are used to locate offset within a long/word + final int bitPos = (combinedHash >>> DEFAULT_BLOCK_SIZE_BITS) & DEFAULT_BIT_OFFSET_MASK; + bitSet.data[absOffset] |= (1L << bitPos); + } + } + + public void addString(String val) { + addBytes(val.getBytes()); + } + + public void addByte(byte val) { + addBytes(new byte[]{val}); + } + + public void addInt(int val) { + // puts int in little endian order + addBytes(intToByteArrayLE(val)); + } + + + public void addLong(long val) { + // puts long in little endian order + addBytes(longToByteArrayLE(val)); + } + + public void addFloat(float val) { + addInt(Float.floatToIntBits(val)); + } + + public void addDouble(double val) { + addLong(Double.doubleToLongBits(val)); + } + + public boolean test(byte[] val) { + return testBytes(val); + } + + public boolean testBytes(byte[] val) { + return testBytes(val, 0, val.length); + } + + public boolean testBytes(byte[] val, int offset, int length) { + long hash64 = val == null ? Murmur3.NULL_HASHCODE : + Murmur3.hash64(val, offset, length); + return testHash(hash64); + } + + private boolean testHash(long hash64) { + final int hash1 = (int) hash64; + final int hash2 = (int) (hash64 >>> 32); + + int firstHash = hash1 + hash2; + // hashcode should be positive, flip all the bits if it's negative + if (firstHash < 0) { + firstHash = ~firstHash; + } + + // first hash is used to locate start of the block (blockBaseOffset) + // subsequent K hashes are used to generate K bits within a block of words + // To avoid branches during probe, a separate masks array is used for each longs/words within a block. + // data array and masks array are then traversed together and checked for corresponding set bits. 
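+    // (masks[i] is non-zero only in the bit positions this key probes in word i; since
+    // (data & mask) ^ mask == 0 exactly when every probed bit is set, OR-ing those values
+    // across the block leaves 0 only on a hit - no branch per probe bit.)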
+    final int blockIdx = firstHash % totalBlockCount;
+    final int blockBaseOffset = blockIdx << DEFAULT_BLOCK_SIZE_BITS;
+
+    // iterate and update masks array
+    for (int i = 1; i <= k; i++) {
+      int combinedHash = hash1 + ((i + 1) * hash2);
+      // hashcode should be positive, flip all the bits if it's negative
+      if (combinedHash < 0) {
+        combinedHash = ~combinedHash;
+      }
+      // LSB 3 bits is used to locate offset within the block
+      final int wordOffset = combinedHash & DEFAULT_BLOCK_OFFSET_MASK;
+      // Next 6 bits are used to locate offset within a long/word
+      final int bitPos = (combinedHash >>> DEFAULT_BLOCK_SIZE_BITS) & DEFAULT_BIT_OFFSET_MASK;
+      masks[wordOffset] |= (1L << bitPos);
+    }
+
+    // traverse data and masks array together, check for set bits
+    long expected = 0;
+    for (int i = 0; i < DEFAULT_BLOCK_SIZE; i++) {
+      final long mask = masks[i];
+      expected |= (bitSet.data[blockBaseOffset + i] & mask) ^ mask;
+    }
+
+    // clear the mask for array reuse (this is to avoid masks array allocation in inner loop)
+    Arrays.fill(masks, 0);
+
+    // if all bits are set, expected should be 0
+    return expected == 0;
+  }
+
+  public boolean testString(String val) {
+    return testBytes(val.getBytes());
+  }
+
+  public boolean testByte(byte val) {
+    return testBytes(new byte[]{val});
+  }
+
+  public boolean testInt(int val) {
+    return testBytes(intToByteArrayLE(val));
+  }
+
+  public boolean testLong(long val) {
+    return testBytes(longToByteArrayLE(val));
+  }
+
+  public boolean testFloat(float val) {
+    return testInt(Float.floatToIntBits(val));
+  }
+
+  public boolean testDouble(double val) {
+    return testLong(Double.doubleToLongBits(val));
+  }
+
+  private byte[] intToByteArrayLE(int val) {
+    BYTE_ARRAY_4[0] = (byte) (val >> 0);
+    BYTE_ARRAY_4[1] = (byte) (val >> 8);
+    BYTE_ARRAY_4[2] = (byte) (val >> 16);
+    BYTE_ARRAY_4[3] = (byte) (val >> 24);
+    return BYTE_ARRAY_4;
+  }
+
+  private byte[] longToByteArrayLE(long val) {
+    BYTE_ARRAY_8[0] = (byte) (val >> 0);
+    BYTE_ARRAY_8[1] = (byte) (val >> 8);
+    BYTE_ARRAY_8[2] = (byte) (val >> 16);
+    BYTE_ARRAY_8[3] = (byte) (val >> 24);
+    BYTE_ARRAY_8[4] = (byte) (val >> 32);
+    BYTE_ARRAY_8[5] = (byte) (val >> 40);
+    BYTE_ARRAY_8[6] = (byte) (val >> 48);
+    BYTE_ARRAY_8[7] = (byte) (val >> 56);
+    return BYTE_ARRAY_8;
+  }
+
+  public long sizeInBytes() {
+    return getBitSize() / 8;
+  }
+
+  public int getBitSize() {
+    return bitSet.getData().length * Long.SIZE;
+  }
+
+  public int getNumHashFunctions() {
+    return k;
+  }
+
+  public int getNumBits() {
+    return m;
+  }
+
+  public long[] getBitSet() {
+    return bitSet.getData();
+  }
+
+  @Override
+  public String toString() {
+    return "m: " + m + " k: " + k;
+  }
+
+  /**
+   * Merge the specified bloom filter with current bloom filter.
+   *
+   * @param that - bloom filter to merge
+   */
+  public void merge(BloomKFilter that) {
+    if (this != that && this.m == that.m && this.k == that.k) {
+      this.bitSet.putAll(that.bitSet);
+    } else {
+      throw new IllegalArgumentException("BloomKFilters are not compatible for merging." +
+          " this - " + this.toString() + " that - " + that.toString());
+    }
+  }
+
+  public void reset() {
+    this.bitSet.clear();
+  }
+
+  /**
+   * Serialize a bloom filter
+   *
+   * @param out output stream to write to
+   * @param bloomFilter BloomKFilter that needs to be serialized
+   */
+  public static void serialize(OutputStream out, BloomKFilter bloomFilter) throws IOException {
+    /**
+     * Serialized BloomKFilter format:
+     * 1 byte for the number of hash functions.
+     * 1 big endian int(That is how OutputStream works) for the number of longs in the bitset
+     * big endian longs in the BloomKFilter bitset
+     */
+    DataOutputStream dataOutputStream = new DataOutputStream(out);
+    dataOutputStream.writeByte(bloomFilter.k);
+    dataOutputStream.writeInt(bloomFilter.getBitSet().length);
+    for (long value : bloomFilter.getBitSet()) {
+      dataOutputStream.writeLong(value);
+    }
+  }
+
+  /**
+   * Deserialize a bloom filter
+   * Read a byte stream, which was written by {@linkplain #serialize(OutputStream, BloomKFilter)}
+   * into a {@code BloomKFilter}
+   *
+   * @param in input bytestream
+   * @return deserialized BloomKFilter
+   */
+  public static BloomKFilter deserialize(InputStream in) throws IOException {
+    if (in == null) {
+      throw new IOException("Input stream is null");
+    }
+
+    try {
+      DataInputStream dataInputStream = new DataInputStream(in);
+      int numHashFunc = dataInputStream.readByte();
+      int bitsetArrayLen = dataInputStream.readInt();
+      long[] data = new long[bitsetArrayLen];
+      for (int i = 0; i < bitsetArrayLen; i++) {
+        data[i] = dataInputStream.readLong();
+      }
+      return new BloomKFilter(data, numHashFunc);
+    } catch (RuntimeException e) {
+      IOException io = new IOException("Unable to deserialize BloomKFilter");
+      io.initCause(e);
+      throw io;
+    }
+  }
+
+  // Given a byte array consisting of a serialized BloomKFilter, gives the offset (from 0)
+  // for the start of the serialized long values that make up the bitset.
+  // NumHashFunctions (1 byte) + bitset array length (4 bytes)
+  public static final int START_OF_SERIALIZED_LONGS = 5;
+
+  /**
+   * Merges BloomKFilter bf2 into bf1.
+   * Assumes 2 BloomKFilters with the same size/hash functions are serialized to byte arrays
+   *
+   * @param bf1Bytes - bytes of the serialized filter that is merged into (modified in place)
+   * @param bf1Start - offset of bf1 within bf1Bytes
+   * @param bf1Length - length of the serialized bf1
+   * @param bf2Bytes - bytes of the serialized filter that is merged from
+   * @param bf2Start - offset of bf2 within bf2Bytes
+   * @param bf2Length - length of the serialized bf2
+   */
+  public static void mergeBloomFilterBytes(
+      byte[] bf1Bytes, int bf1Start, int bf1Length,
+      byte[] bf2Bytes, int bf2Start, int bf2Length) {
+    if (bf1Length != bf2Length) {
+      throw new IllegalArgumentException("bf1Length " + bf1Length + " does not match bf2Length " + bf2Length);
+    }
+
+    // Validation on the bitset size/# hash functions.
+    for (int idx = 0; idx < START_OF_SERIALIZED_LONGS; ++idx) {
+      if (bf1Bytes[bf1Start + idx] != bf2Bytes[bf2Start + idx]) {
+        throw new IllegalArgumentException("bf1 NumHashFunctions/NumBits does not match bf2");
+      }
+    }
+
+    // Just bitwise-OR the bits together - size/# functions should be the same,
+    // rest of the data is serialized long values for the bitset which are supposed to be bitwise-ORed.
+    for (int idx = START_OF_SERIALIZED_LONGS; idx < bf1Length; ++idx) {
+      bf1Bytes[bf1Start + idx] |= bf2Bytes[bf2Start + idx];
+    }
+  }
+
+  /**
+   * Bare metal bit set implementation. For performance reasons, this implementation does not check
+   * for index bounds nor expand the bit set size if the specified index is greater than the size.
+   */
+  public static class BitSet {
+    private final long[] data;
+
+    public BitSet(long bits) {
+      this(new long[(int) Math.ceil((double) bits / (double) Long.SIZE)]);
+    }
+
+    /**
+     * Deserialize long array as bit set.
+     *
+     * @param data - bit array
+     */
+    public BitSet(long[] data) {
+      assert data.length > 0 : "data length is zero!";
+      this.data = data;
+    }
+
+    /**
+     * Sets the bit at specified index.
+     *
+     * @param index - position
+     */
+    public void set(int index) {
+      data[index >>> 6] |= (1L << index);
+    }
+
+    /**
+     * Returns true if the bit is set in the specified index.
+ * + * @param index - position + * @return - value at the bit position + */ + public boolean get(int index) { + return (data[index >>> 6] & (1L << index)) != 0; + } + + /** + * Number of bits + */ + public int bitSize() { + return data.length * Long.SIZE; + } + + public long[] getData() { + return data; + } + + /** + * Combines the two BitArrays using bitwise OR. + */ + public void putAll(BloomKFilter.BitSet array) { + assert data.length == array.data.length : + "BitArrays must be of equal length (" + data.length + "!= " + array.data.length + ")"; + for (int i = 0; i < data.length; i++) { + data[i] |= array.data[i]; + } + } + + /** + * Clear the bit set. + */ + public void clear() { + Arrays.fill(data, 0); + } + } +} diff --git a/storage-api/src/test/org/apache/hive/common/util/TestBloomFilter.java b/storage-api/src/test/org/apache/hive/common/util/TestBloomFilter.java index e4ee93a..cd1fa08 100644 --- a/storage-api/src/test/org/apache/hive/common/util/TestBloomFilter.java +++ b/storage-api/src/test/org/apache/hive/common/util/TestBloomFilter.java @@ -19,6 +19,7 @@ package org.apache.hive.common.util; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -547,11 +548,8 @@ public void testMergeBloomFilterBytesFailureCases() throws Exception { BloomFilter bf1 = new BloomFilter(1000); BloomFilter bf2 = new BloomFilter(200); // Create bloom filter with same number of bits, but different # hash functions - ArrayList bits = new ArrayList(); - for (int idx = 0; idx < bf1.getBitSet().length; ++idx) { - bits.add(0L); - } - BloomFilter bf3 = new BloomFilter(bits, bf1.getBitSize(), bf1.getNumHashFunctions() + 1); + long[] bits = new long[bf1.getBitSet().length]; + BloomFilter bf3 = new BloomFilter(bits, bf1.getNumHashFunctions() + 1); // Serialize to bytes ByteArrayOutputStream bytesOut = new ByteArrayOutputStream(); @@ -586,4 +584,132 @@ public void testMergeBloomFilterBytesFailureCases() throws Exception { // expected } } + + @Test + public void testFpp1K() { + int size = 1000; + BloomFilter bf = new BloomFilter(size); + int fp = 0; + for (int i = 0; i < size; i++) { + bf.addLong(i); + } + + for (int i = 0; i < size; i++) { + assertTrue(bf.testLong(i)); + } + + for (int i = 0; i < size; i++) { + int probe = rand.nextInt(); + // out of range probes + if ((probe > size) || (probe < 0)) { + if (bf.testLong(probe)) { + fp++; + } + } + } + + double actualFpp = (double) fp / (double) size; + double expectedFpp = bf.DEFAULT_FPP; + if (actualFpp < expectedFpp) { + assertTrue(actualFpp != 0.0); + } else { + assertEquals(expectedFpp, actualFpp, 0.005); + } + } + + @Test + public void testFpp10K() { + int size = 10_000; + BloomFilter bf = new BloomFilter(size); + int fp = 0; + for (int i = 0; i < size; i++) { + bf.addLong(i); + } + + for (int i = 0; i < size; i++) { + assertTrue(bf.testLong(i)); + } + + for (int i = 0; i < size; i++) { + int probe = rand.nextInt(); + // out of range probes + if ((probe > size) || (probe < 0)) { + if (bf.testLong(probe)) { + fp++; + } + } + } + + double actualFpp = (double) fp / (double) size; + double expectedFpp = bf.DEFAULT_FPP; + if (actualFpp < expectedFpp) { + assertTrue(actualFpp != 0.0); + } else { + assertEquals(expectedFpp, actualFpp, 0.005); + } + } + + @Test + public void testFpp1M() { + int size = 1_000_000; + BloomFilter bf = new BloomFilter(size); + int fp = 0; + for (int i = 0; i < size; i++) { + bf.addLong(i); + } + + for (int i = 0; i < size; i++) 
{ + assertTrue(bf.testLong(i)); + } + + for (int i = 0; i < size; i++) { + int probe = rand.nextInt(); + // out of range probes + if ((probe > size) || (probe < 0)) { + if (bf.testLong(probe)) { + fp++; + } + } + } + + double actualFpp = (double) fp / (double) size; + double expectedFpp = bf.DEFAULT_FPP; + if (actualFpp < expectedFpp) { + assertTrue(actualFpp != 0.0); + } else { + assertEquals(expectedFpp, actualFpp, 0.005); + } + } + + @Test + public void testFpp10M() { + int size = 10_000_000; + BloomFilter bf = new BloomFilter(size); + int fp = 0; + for (int i = 0; i < size; i++) { + bf.addLong(i); + } + + for (int i = 0; i < size; i++) { + assertTrue(bf.testLong(i)); + } + + for (int i = 0; i < size; i++) { + int probe = rand.nextInt(); + // out of range probes + if ((probe > size) || (probe < 0)) { + if (bf.testLong(probe)) { + fp++; + } + } + } + + double actualFpp = (double) fp / (double) size; + double expectedFpp = bf.DEFAULT_FPP; + if (actualFpp < expectedFpp) { + assertTrue(actualFpp != 0.0); + } else { + assertEquals(expectedFpp, actualFpp, 0.005); + } + } } diff --git a/storage-api/src/test/org/apache/hive/common/util/TestBloomKFilter.java b/storage-api/src/test/org/apache/hive/common/util/TestBloomKFilter.java new file mode 100644 index 0000000..159fab2 --- /dev/null +++ b/storage-api/src/test/org/apache/hive/common/util/TestBloomKFilter.java @@ -0,0 +1,699 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hive.common.util; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.util.Random; + +import org.junit.Assert; +import org.junit.Test; + +/** + * + */ +public class TestBloomKFilter { + private static final int COUNT = 100; + Random rand = new Random(123); + // bloom-1 is known to have higher fpp, to make tests pass give room for another 3% + private final double deltaError = 0.03; + + @Test(expected = IllegalArgumentException.class) + public void testBloomIllegalArg1() { + BloomKFilter bf = new BloomKFilter(0); + } + + @Test + public void testBloomNumBits() { + assertEquals(0, BloomKFilter.optimalNumOfBits(0, 0)); + assertEquals(0, BloomKFilter.optimalNumOfBits(0, 1)); + assertEquals(0, BloomKFilter.optimalNumOfBits(1, 1)); + assertEquals(7, BloomKFilter.optimalNumOfBits(1, 0.03)); + assertEquals(72, BloomKFilter.optimalNumOfBits(10, 0.03)); + assertEquals(729, BloomKFilter.optimalNumOfBits(100, 0.03)); + assertEquals(7298, BloomKFilter.optimalNumOfBits(1000, 0.03)); + assertEquals(72984, BloomKFilter.optimalNumOfBits(10000, 0.03)); + assertEquals(729844, BloomKFilter.optimalNumOfBits(100000, 0.03)); + assertEquals(7298440, BloomKFilter.optimalNumOfBits(1000000, 0.03)); + assertEquals(6235224, BloomKFilter.optimalNumOfBits(1000000, 0.05)); + assertEquals(1870567268, BloomKFilter.optimalNumOfBits(300000000, 0.05)); + assertEquals(1437758756, BloomKFilter.optimalNumOfBits(300000000, 0.1)); + assertEquals(432808512, BloomKFilter.optimalNumOfBits(300000000, 0.5)); + assertEquals(1393332198, BloomKFilter.optimalNumOfBits(3000000000L, 0.8)); + assertEquals(657882327, BloomKFilter.optimalNumOfBits(3000000000L, 0.9)); + assertEquals(0, BloomKFilter.optimalNumOfBits(3000000000L, 1)); + + BloomKFilter bloomKFilter = new BloomKFilter(40); + assertEquals(8, bloomKFilter.getBitSet().length); + assertEquals(bloomKFilter.getNumBits(), bloomKFilter.getBitSize()); + } + + @Test + public void testBloomNumHashFunctions() { + assertEquals(1, BloomKFilter.optimalNumOfHashFunctions(-1, -1)); + assertEquals(1, BloomKFilter.optimalNumOfHashFunctions(0, 0)); + assertEquals(1, BloomKFilter.optimalNumOfHashFunctions(10, 0)); + assertEquals(1, BloomKFilter.optimalNumOfHashFunctions(10, 10)); + assertEquals(7, BloomKFilter.optimalNumOfHashFunctions(10, 100)); + assertEquals(1, BloomKFilter.optimalNumOfHashFunctions(100, 100)); + assertEquals(1, BloomKFilter.optimalNumOfHashFunctions(1000, 100)); + assertEquals(1, BloomKFilter.optimalNumOfHashFunctions(10000, 100)); + assertEquals(1, BloomKFilter.optimalNumOfHashFunctions(100000, 100)); + assertEquals(1, BloomKFilter.optimalNumOfHashFunctions(1000000, 100)); + } + + @Test + public void testBloomKFilterBytes() { + BloomKFilter bf = new BloomKFilter(10000); + byte[] val = new byte[]{1, 2, 3}; + byte[] val1 = new byte[]{1, 2, 3, 4}; + byte[] val2 = new byte[]{1, 2, 3, 4, 5}; + byte[] val3 = new byte[]{1, 2, 3, 4, 5, 6}; + + assertEquals(false, bf.test(val)); + assertEquals(false, bf.test(val1)); + assertEquals(false, bf.test(val2)); + assertEquals(false, bf.test(val3)); + bf.add(val); + assertEquals(true, bf.test(val)); + assertEquals(false, bf.test(val1)); + assertEquals(false, bf.test(val2)); + assertEquals(false, bf.test(val3)); + bf.add(val1); + assertEquals(true, bf.test(val)); + assertEquals(true, bf.test(val1)); + assertEquals(false, bf.test(val2)); + assertEquals(false, bf.test(val3)); + bf.add(val2); + 
assertEquals(true, bf.test(val)); + assertEquals(true, bf.test(val1)); + assertEquals(true, bf.test(val2)); + assertEquals(false, bf.test(val3)); + bf.add(val3); + assertEquals(true, bf.test(val)); + assertEquals(true, bf.test(val1)); + assertEquals(true, bf.test(val2)); + assertEquals(true, bf.test(val3)); + + byte[] randVal = new byte[COUNT]; + for (int i = 0; i < COUNT; i++) { + rand.nextBytes(randVal); + bf.add(randVal); + } + // last value should be present + assertEquals(true, bf.test(randVal)); + // most likely this value should not exist + randVal[0] = 0; + randVal[1] = 0; + randVal[2] = 0; + randVal[3] = 0; + randVal[4] = 0; + assertEquals(false, bf.test(randVal)); + + assertEquals(7808, bf.sizeInBytes()); + } + + @Test + public void testBloomKFilterByte() { + BloomKFilter bf = new BloomKFilter(10000); + byte val = Byte.MIN_VALUE; + byte val1 = 1; + byte val2 = 2; + byte val3 = Byte.MAX_VALUE; + + assertEquals(false, bf.testLong(val)); + assertEquals(false, bf.testLong(val1)); + assertEquals(false, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val); + assertEquals(true, bf.testLong(val)); + assertEquals(false, bf.testLong(val1)); + assertEquals(false, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val1); + assertEquals(true, bf.testLong(val)); + assertEquals(true, bf.testLong(val1)); + assertEquals(false, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val2); + assertEquals(true, bf.testLong(val)); + assertEquals(true, bf.testLong(val1)); + assertEquals(true, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val3); + assertEquals(true, bf.testLong(val)); + assertEquals(true, bf.testLong(val1)); + assertEquals(true, bf.testLong(val2)); + assertEquals(true, bf.testLong(val3)); + + byte randVal = 0; + for (int i = 0; i < COUNT; i++) { + randVal = (byte) rand.nextInt(Byte.MAX_VALUE); + bf.addLong(randVal); + } + // last value should be present + assertEquals(true, bf.testLong(randVal)); + // most likely this value should not exist + assertEquals(false, bf.testLong((byte) -120)); + + assertEquals(7808, bf.sizeInBytes()); + } + + @Test + public void testBloomKFilterInt() { + BloomKFilter bf = new BloomKFilter(10000); + int val = Integer.MIN_VALUE; + int val1 = 1; + int val2 = 2; + int val3 = Integer.MAX_VALUE; + + assertEquals(false, bf.testLong(val)); + assertEquals(false, bf.testLong(val1)); + assertEquals(false, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val); + assertEquals(true, bf.testLong(val)); + assertEquals(false, bf.testLong(val1)); + assertEquals(false, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val1); + assertEquals(true, bf.testLong(val)); + assertEquals(true, bf.testLong(val1)); + assertEquals(false, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val2); + assertEquals(true, bf.testLong(val)); + assertEquals(true, bf.testLong(val1)); + assertEquals(true, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val3); + assertEquals(true, bf.testLong(val)); + assertEquals(true, bf.testLong(val1)); + assertEquals(true, bf.testLong(val2)); + assertEquals(true, bf.testLong(val3)); + + int randVal = 0; + for (int i = 0; i < COUNT; i++) { + randVal = rand.nextInt(); + bf.addLong(randVal); + } + // last value should be present + assertEquals(true, bf.testLong(randVal)); + // most likely this value should not exist + assertEquals(false, bf.testLong(-120)); + + 
assertEquals(7808, bf.sizeInBytes()); + } + + @Test + public void testBloomKFilterLong() { + BloomKFilter bf = new BloomKFilter(10000); + long val = Long.MIN_VALUE; + long val1 = 1; + long val2 = 2; + long val3 = Long.MAX_VALUE; + + assertEquals(false, bf.testLong(val)); + assertEquals(false, bf.testLong(val1)); + assertEquals(false, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val); + assertEquals(true, bf.testLong(val)); + assertEquals(false, bf.testLong(val1)); + assertEquals(false, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val1); + assertEquals(true, bf.testLong(val)); + assertEquals(true, bf.testLong(val1)); + assertEquals(false, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val2); + assertEquals(true, bf.testLong(val)); + assertEquals(true, bf.testLong(val1)); + assertEquals(true, bf.testLong(val2)); + assertEquals(false, bf.testLong(val3)); + bf.addLong(val3); + assertEquals(true, bf.testLong(val)); + assertEquals(true, bf.testLong(val1)); + assertEquals(true, bf.testLong(val2)); + assertEquals(true, bf.testLong(val3)); + + long randVal = 0; + for (int i = 0; i < COUNT; i++) { + randVal = rand.nextLong(); + bf.addLong(randVal); + } + // last value should be present + assertEquals(true, bf.testLong(randVal)); + // most likely this value should not exist + assertEquals(false, bf.testLong(-120)); + + assertEquals(7808, bf.sizeInBytes()); + } + + @Test + public void testBloomKFilterFloat() { + BloomKFilter bf = new BloomKFilter(10000); + float val = Float.MIN_VALUE; + float val1 = 1.1f; + float val2 = 2.2f; + float val3 = Float.MAX_VALUE; + + assertEquals(false, bf.testDouble(val)); + assertEquals(false, bf.testDouble(val1)); + assertEquals(false, bf.testDouble(val2)); + assertEquals(false, bf.testDouble(val3)); + bf.addDouble(val); + assertEquals(true, bf.testDouble(val)); + assertEquals(false, bf.testDouble(val1)); + assertEquals(false, bf.testDouble(val2)); + assertEquals(false, bf.testDouble(val3)); + bf.addDouble(val1); + assertEquals(true, bf.testDouble(val)); + assertEquals(true, bf.testDouble(val1)); + assertEquals(false, bf.testDouble(val2)); + assertEquals(false, bf.testDouble(val3)); + bf.addDouble(val2); + assertEquals(true, bf.testDouble(val)); + assertEquals(true, bf.testDouble(val1)); + assertEquals(true, bf.testDouble(val2)); + assertEquals(false, bf.testDouble(val3)); + bf.addDouble(val3); + assertEquals(true, bf.testDouble(val)); + assertEquals(true, bf.testDouble(val1)); + assertEquals(true, bf.testDouble(val2)); + assertEquals(true, bf.testDouble(val3)); + + float randVal = 0; + for (int i = 0; i < COUNT; i++) { + randVal = rand.nextFloat(); + bf.addDouble(randVal); + } + // last value should be present + assertEquals(true, bf.testDouble(randVal)); + // most likely this value should not exist + assertEquals(false, bf.testDouble(-120.2f)); + + assertEquals(7808, bf.sizeInBytes()); + } + + @Test + public void testBloomKFilterDouble() { + BloomKFilter bf = new BloomKFilter(10000); + double val = Double.MIN_VALUE; + double val1 = 1.1d; + double val2 = 2.2d; + double val3 = Double.MAX_VALUE; + + assertEquals(false, bf.testDouble(val)); + assertEquals(false, bf.testDouble(val1)); + assertEquals(false, bf.testDouble(val2)); + assertEquals(false, bf.testDouble(val3)); + bf.addDouble(val); + assertEquals(true, bf.testDouble(val)); + assertEquals(false, bf.testDouble(val1)); + assertEquals(false, bf.testDouble(val2)); + assertEquals(false, bf.testDouble(val3)); + bf.addDouble(val1); + 
assertEquals(true, bf.testDouble(val)); + assertEquals(true, bf.testDouble(val1)); + assertEquals(false, bf.testDouble(val2)); + assertEquals(false, bf.testDouble(val3)); + bf.addDouble(val2); + assertEquals(true, bf.testDouble(val)); + assertEquals(true, bf.testDouble(val1)); + assertEquals(true, bf.testDouble(val2)); + assertEquals(false, bf.testDouble(val3)); + bf.addDouble(val3); + assertEquals(true, bf.testDouble(val)); + assertEquals(true, bf.testDouble(val1)); + assertEquals(true, bf.testDouble(val2)); + assertEquals(true, bf.testDouble(val3)); + + double randVal = 0; + for (int i = 0; i < COUNT; i++) { + randVal = rand.nextDouble(); + bf.addDouble(randVal); + } + // last value should be present + assertEquals(true, bf.testDouble(randVal)); + // most likely this value should not exist + assertEquals(false, bf.testDouble(-120.2d)); + + assertEquals(7808, bf.sizeInBytes()); + } + + @Test + public void testBloomKFilterString() { + BloomKFilter bf = new BloomKFilter(100000); + String val = "bloo"; + String val1 = "bloom fil"; + String val2 = "bloom filter"; + String val3 = "cuckoo filter"; + + assertEquals(false, bf.testString(val)); + assertEquals(false, bf.testString(val1)); + assertEquals(false, bf.testString(val2)); + assertEquals(false, bf.testString(val3)); + bf.addString(val); + assertEquals(true, bf.testString(val)); + assertEquals(false, bf.testString(val1)); + assertEquals(false, bf.testString(val2)); + assertEquals(false, bf.testString(val3)); + bf.addString(val1); + assertEquals(true, bf.testString(val)); + assertEquals(true, bf.testString(val1)); + assertEquals(false, bf.testString(val2)); + assertEquals(false, bf.testString(val3)); + bf.addString(val2); + assertEquals(true, bf.testString(val)); + assertEquals(true, bf.testString(val1)); + assertEquals(true, bf.testString(val2)); + assertEquals(false, bf.testString(val3)); + bf.addString(val3); + assertEquals(true, bf.testString(val)); + assertEquals(true, bf.testString(val1)); + assertEquals(true, bf.testString(val2)); + assertEquals(true, bf.testString(val3)); + + long randVal = 0; + for (int i = 0; i < COUNT; i++) { + randVal = rand.nextLong(); + bf.addString(Long.toString(randVal)); + } + // last value should be present + assertEquals(true, bf.testString(Long.toString(randVal))); + // most likely this value should not exist + assertEquals(false, bf.testString(Long.toString(-120))); + + assertEquals(77952, bf.sizeInBytes()); + } + + @Test + public void testMerge() { + BloomKFilter bf = new BloomKFilter(10000); + String val = "bloo"; + String val1 = "bloom fil"; + String val2 = "bloom filter"; + String val3 = "cuckoo filter"; + bf.addString(val); + bf.addString(val1); + bf.addString(val2); + bf.addString(val3); + + BloomKFilter bf2 = new BloomKFilter(10000); + String v = "2_bloo"; + String v1 = "2_bloom fil"; + String v2 = "2_bloom filter"; + String v3 = "2_cuckoo filter"; + bf2.addString(v); + bf2.addString(v1); + bf2.addString(v2); + bf2.addString(v3); + + assertEquals(true, bf.testString(val)); + assertEquals(true, bf.testString(val1)); + assertEquals(true, bf.testString(val2)); + assertEquals(true, bf.testString(val3)); + assertEquals(false, bf.testString(v)); + assertEquals(false, bf.testString(v1)); + assertEquals(false, bf.testString(v2)); + assertEquals(false, bf.testString(v3)); + + bf.merge(bf2); + + assertEquals(true, bf.testString(val)); + assertEquals(true, bf.testString(val1)); + assertEquals(true, bf.testString(val2)); + assertEquals(true, bf.testString(val3)); + assertEquals(true, bf.testString(v)); + 
+    assertEquals(true, bf.testString(v1));
+    assertEquals(true, bf.testString(v2));
+    assertEquals(true, bf.testString(v3));
+  }
+
+  @Test
+  public void testSerialize() throws Exception {
+    BloomKFilter bf1 = new BloomKFilter(10000);
+    String[] inputs = {
+        "bloo",
+        "bloom fil",
+        "bloom filter",
+        "cuckoo filter",
+    };
+
+    for (String val : inputs) {
+      bf1.addString(val);
+    }
+
+    // Serialize/deserialize
+    ByteArrayOutputStream bytesOut = new ByteArrayOutputStream();
+    BloomKFilter.serialize(bytesOut, bf1);
+    ByteArrayInputStream bytesIn = new ByteArrayInputStream(bytesOut.toByteArray());
+    BloomKFilter bf2 = BloomKFilter.deserialize(bytesIn);
+
+    for (String val : inputs) {
+      assertEquals("Testing bf1 with " + val, true, bf1.testString(val));
+      assertEquals("Testing bf2 with " + val, true, bf2.testString(val));
+    }
+  }
+
+  @Test
+  public void testMergeBloomKFilterBytes() throws Exception {
+    BloomKFilter bf1 = new BloomKFilter(10000);
+    BloomKFilter bf2 = new BloomKFilter(10000);
+
+    String[] inputs1 = {
+        "bloo",
+        "bloom fil",
+        "bloom filter",
+        "cuckoo filter",
+    };
+
+    String[] inputs2 = {
+        "2_bloo",
+        "2_bloom fil",
+        "2_bloom filter",
+        "2_cuckoo filter",
+    };
+
+    for (String val : inputs1) {
+      bf1.addString(val);
+    }
+    for (String val : inputs2) {
+      bf2.addString(val);
+    }
+
+    ByteArrayOutputStream bytesOut = new ByteArrayOutputStream();
+    BloomKFilter.serialize(bytesOut, bf1);
+    byte[] bf1Bytes = bytesOut.toByteArray();
+    bytesOut.reset();
+    BloomKFilter.serialize(bytesOut, bf2);
+    byte[] bf2Bytes = bytesOut.toByteArray();
+
+    // Merge bytes
+    BloomKFilter.mergeBloomFilterBytes(
+        bf1Bytes, 0, bf1Bytes.length,
+        bf2Bytes, 0, bf2Bytes.length);
+
+    // Deserialize and test
+    ByteArrayInputStream bytesIn = new ByteArrayInputStream(bf1Bytes, 0, bf1Bytes.length);
+    BloomKFilter bfMerged = BloomKFilter.deserialize(bytesIn);
+    // All values should pass test
+    for (String val : inputs1) {
+      assertTrue(bfMerged.testString(val));
+    }
+    for (String val : inputs2) {
+      assertTrue(bfMerged.testString(val));
+    }
+  }
+
+  @Test
+  public void testMergeBloomKFilterBytesFailureCases() throws Exception {
+    BloomKFilter bf1 = new BloomKFilter(1000);
+    BloomKFilter bf2 = new BloomKFilter(200);
+    // Create bloom filter with same number of bits, but different # hash functions
+    long[] bits = new long[bf1.getBitSet().length];
+    BloomKFilter bf3 = new BloomKFilter(bits, bf1.getNumHashFunctions() + 1);
+
+    // Serialize to bytes
+    ByteArrayOutputStream bytesOut = new ByteArrayOutputStream();
+    BloomKFilter.serialize(bytesOut, bf1);
+    byte[] bf1Bytes = bytesOut.toByteArray();
+
+    bytesOut.reset();
+    BloomKFilter.serialize(bytesOut, bf2);
+    byte[] bf2Bytes = bytesOut.toByteArray();
+
+    bytesOut.reset();
+    BloomKFilter.serialize(bytesOut, bf3);
+    byte[] bf3Bytes = bytesOut.toByteArray();
+
+    try {
+      // this should fail
+      BloomKFilter.mergeBloomFilterBytes(
+          bf1Bytes, 0, bf1Bytes.length,
+          bf2Bytes, 0, bf2Bytes.length);
+      Assert.fail("Expected exception not encountered");
+    } catch (IllegalArgumentException err) {
+      // expected
+    }
+
+    try {
+      // this should fail
+      BloomKFilter.mergeBloomFilterBytes(
+          bf1Bytes, 0, bf1Bytes.length,
+          bf3Bytes, 0, bf3Bytes.length);
+      Assert.fail("Expected exception not encountered");
+    } catch (IllegalArgumentException err) {
+      // expected
+    }
+  }
+
+  @Test
+  public void testFpp1K() {
+    int size = 1000;
+    BloomKFilter bf = new BloomKFilter(size);
+    int fp = 0;
+    for (int i = 0; i < size; i++) {
+      bf.addLong(i);
+    }
+
+    for (int i = 0; i < size; i++) {
+      assertTrue(bf.testLong(i));
+    }
+
+    for (int i = 0; i
< size; i++) { + int probe = rand.nextInt(); + // out of range probes + if ((probe > size) || (probe < 0)) { + if (bf.testLong(probe)) { + fp++; + } + } + } + + double actualFpp = (double) fp / (double) size; + double expectedFpp = bf.DEFAULT_FPP; + if (actualFpp < expectedFpp) { + assertTrue(actualFpp != 0.0); + } else { + assertEquals(expectedFpp, actualFpp, deltaError); + } + } + + @Test + public void testFpp10K() { + int size = 10_000; + BloomKFilter bf = new BloomKFilter(size); + int fp = 0; + for (int i = 0; i < size; i++) { + bf.addLong(i); + } + + for (int i = 0; i < size; i++) { + assertTrue(bf.testLong(i)); + } + + for (int i = 0; i < size; i++) { + int probe = rand.nextInt(); + // out of range probes + if ((probe > size) || (probe < 0)) { + if (bf.testLong(probe)) { + fp++; + } + } + } + + double actualFpp = (double) fp / (double) size; + double expectedFpp = bf.DEFAULT_FPP; + if (actualFpp < expectedFpp) { + assertTrue(actualFpp != 0.0); + } else { + assertEquals(expectedFpp, actualFpp, deltaError); + } + } + + @Test + public void testFpp1M() { + int size = 1_000_000; + BloomKFilter bf = new BloomKFilter(size); + int fp = 0; + for (int i = 0; i < size; i++) { + bf.addLong(i); + } + + for (int i = 0; i < size; i++) { + assertTrue(bf.testLong(i)); + } + + for (int i = 0; i < size; i++) { + int probe = rand.nextInt(); + // out of range probes + if ((probe > size) || (probe < 0)) { + if (bf.testLong(probe)) { + fp++; + } + } + } + + double actualFpp = (double) fp / (double) size; + double expectedFpp = bf.DEFAULT_FPP; + if (actualFpp < expectedFpp) { + assertTrue(actualFpp != 0.0); + } else { + assertEquals(expectedFpp, actualFpp, deltaError); + } + } + + @Test + public void testFpp10M() { + int size = 10_000_000; + BloomKFilter bf = new BloomKFilter(size); + int fp = 0; + for (int i = 0; i < size; i++) { + bf.addLong(i); + } + + for (int i = 0; i < size; i++) { + assertTrue(bf.testLong(i)); + } + + for (int i = 0; i < size; i++) { + int probe = rand.nextInt(); + // out of range probes + if ((probe > size) || (probe < 0)) { + if (bf.testLong(probe)) { + fp++; + } + } + } + + double actualFpp = (double) fp / (double) size; + double expectedFpp = bf.DEFAULT_FPP; + if (actualFpp < expectedFpp) { + assertTrue(actualFpp != 0.0); + } else { + assertEquals(expectedFpp, actualFpp, deltaError); + } + } +}
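Taken together, the tests above exercise the same byte-level path that VectorUDAFBloomFilterMerge uses at runtime. A minimal end-to-end sketch of that path, using only the APIs added in this patch (the class name BloomKFilterMergeDemo, the sample keys, and the sizes are illustrative):

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;

import org.apache.hive.common.util.BloomKFilter;

public class BloomKFilterMergeDemo {
  public static void main(String[] args) throws Exception {
    // Two partial filters built with the same expectedEntries, so their
    // serialized headers (numHashFunctions + long count) match.
    BloomKFilter a = new BloomKFilter(10_000);
    BloomKFilter b = new BloomKFilter(10_000);
    a.addString("left-side key");
    b.addString("right-side key");

    byte[] aBytes = serialize(a);
    byte[] bBytes = serialize(b);

    // OR b into a at the byte level, skipping the 5-byte header
    // (START_OF_SERIALIZED_LONGS) -- no deserialization needed.
    BloomKFilter.mergeBloomFilterBytes(aBytes, 0, aBytes.length,
        bBytes, 0, bBytes.length);

    BloomKFilter merged = BloomKFilter.deserialize(new ByteArrayInputStream(aBytes));
    System.out.println(merged.testString("left-side key"));  // true
    System.out.println(merged.testString("right-side key")); // true
  }

  private static byte[] serialize(BloomKFilter bf) throws Exception {
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    BloomKFilter.serialize(out, bf);
    return out.toByteArray();
  }
}
```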