diff --git lucene/core/src/java/org/apache/lucene/util/packed/DirectPackedReader.java lucene/core/src/java/org/apache/lucene/util/packed/DirectPackedReader.java index 009124e..fd8ac7c 100644 --- lucene/core/src/java/org/apache/lucene/util/packed/DirectPackedReader.java +++ lucene/core/src/java/org/apache/lucene/util/packed/DirectPackedReader.java @@ -25,12 +25,18 @@ import java.io.IOException; class DirectPackedReader extends PackedInts.ReaderImpl { private final IndexInput in; private final long startPointer; + private final long valueMask; public DirectPackedReader(int bitsPerValue, int valueCount, IndexInput in) { super(valueCount, bitsPerValue); this.in = in; startPointer = in.getFilePointer(); + if (bitsPerValue == 64) { + valueMask = -1L; + } else { + valueMask = (1L << bitsPerValue) - 1; + } } @Override @@ -40,29 +46,49 @@ class DirectPackedReader extends PackedInts.ReaderImpl { try { in.seek(startPointer + elementPos); - final byte b0 = in.readByte(); final int bitPos = (int) (majorBitPos & 7); - if (bitPos + bitsPerValue <= 8) { - // special case: all bits are in the first byte - return (b0 & ((1L << (8 - bitPos)) - 1)) >>> (8 - bitPos - bitsPerValue); - } - - // take bits from the first byte - int remainingBits = bitsPerValue - 8 + bitPos; - long result = (b0 & ((1L << (8 - bitPos)) - 1)) << remainingBits; - - // add bits from inner bytes - while (remainingBits >= 8) { - remainingBits -= 8; - result |= (in.readByte() & 0xFFL) << remainingBits; - } + // round up bits to a multiple of 8 to find total bytes needed to read + final int roundedBits = ((bitPos + bitsPerValue + 7) & ~7); + // the number of extra bits read at the end to shift out + int shiftRightBits = roundedBits - bitPos - bitsPerValue; - // take bits from the last byte - if (remainingBits > 0) { - result |= (in.readByte() & 0xFFL) >>> (8 - remainingBits); + long rawValue; + switch (roundedBits / 8) { + case 1: + rawValue = in.readByte(); + break; + case 2: + rawValue = in.readShort(); + break; + case 3: + rawValue = ((long)in.readShort() << 8) | (in.readByte() & 0xFFL); + break; + case 4: + rawValue = in.readInt(); + break; + case 5: + rawValue = ((long)in.readInt() << 8) | (in.readByte() & 0xFFL); + break; + case 6: + rawValue = ((long)in.readInt() << 16) | (in.readShort() & 0xFFFFL); + break; + case 7: + rawValue = ((long)in.readInt() << 24) | ((in.readShort() & 0xFFFFL) << 8) | (in.readByte() & 0xFFL); + break; + case 8: + rawValue = in.readLong(); + break; + case 9: + // We must be very careful not to shift out relevant bits. So we account for right shift + // we would normally do on return here, and reset it. + rawValue = (in.readLong() << (8 - shiftRightBits)) | ((in.readByte() & 0xFFL) >>> shiftRightBits); + shiftRightBits = 0; + break; + default: + throw new AssertionError("bitsPerValue too large: " + bitsPerValue); } + return (rawValue >>> shiftRightBits) & valueMask; - return result; } catch (IOException ioe) { throw new IllegalStateException("failed", ioe); }