diff --git a/serde/src/java/org/apache/hadoop/hive/serde2/thrift/ColumnBuffer.java b/serde/src/java/org/apache/hadoop/hive/serde2/thrift/ColumnBuffer.java index 3ce1fb3..d603963 100644 --- a/serde/src/java/org/apache/hadoop/hive/serde2/thrift/ColumnBuffer.java +++ b/serde/src/java/org/apache/hadoop/hive/serde2/thrift/ColumnBuffer.java @@ -25,6 +25,7 @@ import java.util.BitSet; import java.util.List; +import com.google.common.annotations.VisibleForTesting; import org.apache.hive.service.rpc.thrift.TBinaryColumn; import org.apache.hive.service.rpc.thrift.TBoolColumn; import org.apache.hive.service.rpc.thrift.TByteColumn; @@ -177,73 +178,83 @@ public ColumnBuffer(TColumn colValues) { } } - public ColumnBuffer extractSubset(int start, int end) { - BitSet subNulls = nulls.get(start, end); + /** + * Get a subset of this ColumnBuffer, starting from the 1st value. + * + * @param end index after the last value to include + */ + public ColumnBuffer extractSubset(int end) { + BitSet subNulls = nulls.get(0, end); if (type == Type.BOOLEAN_TYPE) { ColumnBuffer subset = - new ColumnBuffer(type, subNulls, Arrays.copyOfRange(boolVars, start, end)); + new ColumnBuffer(type, subNulls, Arrays.copyOfRange(boolVars, 0, end)); boolVars = Arrays.copyOfRange(boolVars, end, size); - nulls = nulls.get(start, size); + nulls = nulls.get(end, size); size = boolVars.length; return subset; } if (type == Type.TINYINT_TYPE) { ColumnBuffer subset = - new ColumnBuffer(type, subNulls, Arrays.copyOfRange(byteVars, start, end)); + new ColumnBuffer(type, subNulls, Arrays.copyOfRange(byteVars, 0, end)); byteVars = Arrays.copyOfRange(byteVars, end, size); - nulls = nulls.get(start, size); + nulls = nulls.get(end, size); size = byteVars.length; return subset; } if (type == Type.SMALLINT_TYPE) { ColumnBuffer subset = - new ColumnBuffer(type, subNulls, Arrays.copyOfRange(shortVars, start, end)); + new ColumnBuffer(type, subNulls, Arrays.copyOfRange(shortVars, 0, end)); shortVars = Arrays.copyOfRange(shortVars, end, size); - nulls = nulls.get(start, size); + nulls = nulls.get(end, size); size = shortVars.length; return subset; } if (type == Type.INT_TYPE) { ColumnBuffer subset = - new ColumnBuffer(type, subNulls, Arrays.copyOfRange(intVars, start, end)); + new ColumnBuffer(type, subNulls, Arrays.copyOfRange(intVars, 0, end)); intVars = Arrays.copyOfRange(intVars, end, size); - nulls = nulls.get(start, size); + nulls = nulls.get(end, size); size = intVars.length; return subset; } if (type == Type.BIGINT_TYPE) { ColumnBuffer subset = - new ColumnBuffer(type, subNulls, Arrays.copyOfRange(longVars, start, end)); + new ColumnBuffer(type, subNulls, Arrays.copyOfRange(longVars, 0, end)); longVars = Arrays.copyOfRange(longVars, end, size); - nulls = nulls.get(start, size); + nulls = nulls.get(end, size); size = longVars.length; return subset; } if (type == Type.DOUBLE_TYPE || type == Type.FLOAT_TYPE) { ColumnBuffer subset = - new ColumnBuffer(type, subNulls, Arrays.copyOfRange(doubleVars, start, end)); + new ColumnBuffer(type, subNulls, Arrays.copyOfRange(doubleVars, 0, end)); doubleVars = Arrays.copyOfRange(doubleVars, end, size); - nulls = nulls.get(start, size); + nulls = nulls.get(end, size); size = doubleVars.length; return subset; } if (type == Type.BINARY_TYPE) { - ColumnBuffer subset = new ColumnBuffer(type, subNulls, binaryVars.subList(start, end)); + ColumnBuffer subset = new ColumnBuffer(type, subNulls, binaryVars.subList(0, end)); binaryVars = binaryVars.subList(end, binaryVars.size()); - nulls = nulls.get(start, size); + nulls = nulls.get(end, size); size = binaryVars.size(); return subset; } if (type == Type.STRING_TYPE) { - ColumnBuffer subset = new ColumnBuffer(type, subNulls, stringVars.subList(start, end)); + ColumnBuffer subset = new ColumnBuffer(type, subNulls, stringVars.subList(0, end)); stringVars = stringVars.subList(end, stringVars.size()); - nulls = nulls.get(start, size); + nulls = nulls.get(end, size); size = stringVars.size(); return subset; } throw new IllegalStateException("invalid union object"); } + @VisibleForTesting + BitSet getNulls() { + return nulls; + } + private static final byte[] MASKS = new byte[] { 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, (byte)0x80 }; diff --git a/serde/src/test/org/apache/hadoop/hive/serde2/thrift/TestColumnBuffer.java b/serde/src/test/org/apache/hadoop/hive/serde2/thrift/TestColumnBuffer.java new file mode 100644 index 0000000..610636b --- /dev/null +++ b/serde/src/test/org/apache/hadoop/hive/serde2/thrift/TestColumnBuffer.java @@ -0,0 +1,136 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.serde2.thrift; + +import com.google.code.tempusfugit.concurrency.RepeatingRule; +import com.google.code.tempusfugit.concurrency.annotations.Repeating; +import org.apache.hadoop.hive.serde2.thrift.ColumnBuffer; +import org.apache.hadoop.hive.serde2.thrift.Type; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.BitSet; +import java.util.Collection; +import java.util.HashSet; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.ThreadLocalRandom; + +@RunWith(Parameterized.class) +public class TestColumnBuffer { + @Rule + public RepeatingRule repeatingRule = new RepeatingRule(); + + private static final int NUM_VARS = 100; + private static final int NUM_NULLS = 30; + private static final Set nullIndices = new HashSet<>(); + + private final Type type; + private final Object vars; + + @Parameterized.Parameters + public static Collection types() { + return Arrays.asList(new Object[][]{ + {Type.BOOLEAN_TYPE}, + {Type.TINYINT_TYPE}, + {Type.SMALLINT_TYPE}, + {Type.INT_TYPE}, + {Type.BIGINT_TYPE}, + {Type.DOUBLE_TYPE}, + {Type.FLOAT_TYPE}, + {Type.BINARY_TYPE}, + {Type.STRING_TYPE} + } + ); + } + + public TestColumnBuffer(Type type) { + this.type = type; + switch (type) { + case BOOLEAN_TYPE: + vars = new boolean[NUM_VARS]; + break; + case TINYINT_TYPE: + vars = new byte[NUM_VARS]; + break; + case SMALLINT_TYPE: + vars = new short[NUM_VARS]; + break; + case INT_TYPE: + vars = new int[NUM_VARS]; + break; + case BIGINT_TYPE: + vars = new long[NUM_VARS]; + break; + case DOUBLE_TYPE: + case FLOAT_TYPE: + vars = new double[NUM_VARS]; + break; + case BINARY_TYPE: + vars = Arrays.asList(new ByteBuffer[NUM_VARS]); + break; + case STRING_TYPE: + vars = Arrays.asList(new String[NUM_VARS]); + break; + default: + throw new IllegalArgumentException("Invalid type " + type); + } + } + + private static void prepareNullIndices() { + nullIndices.clear(); + Random random = ThreadLocalRandom.current(); + while (nullIndices.size() != NUM_NULLS) { + nullIndices.add(random.nextInt(NUM_VARS)); + } + } + + @Test + @Repeating(repetition=10) + public void testExtractSubset() { + prepareNullIndices(); + BitSet nulls = new BitSet(NUM_VARS); + for (int index : nullIndices) { + nulls.set(index); + } + + ColumnBuffer columnBuffer = new ColumnBuffer(type, nulls, vars); + Random random = ThreadLocalRandom.current(); + + int remaining = NUM_VARS; + while (remaining > 0) { + int toExtract = random.nextInt(remaining) + 1; + ColumnBuffer subset = columnBuffer.extractSubset(toExtract); + verifyNulls(subset, NUM_VARS - remaining); + remaining -= toExtract; + } + } + + private static void verifyNulls(ColumnBuffer buffer, int shift) { + BitSet nulls = buffer.getNulls(); + for (int i = 0; i < buffer.size(); i++) { + Assert.assertEquals("BitSet in parent and subset not the same.", + nullIndices.contains(i + shift), nulls.get(i)); + } + } +} diff --git a/service/src/java/org/apache/hive/service/cli/ColumnBasedSet.java b/service/src/java/org/apache/hive/service/cli/ColumnBasedSet.java index 9cbe89c..3774426 100644 --- a/service/src/java/org/apache/hive/service/cli/ColumnBasedSet.java +++ b/service/src/java/org/apache/hive/service/cli/ColumnBasedSet.java @@ -137,7 +137,7 @@ public ColumnBasedSet extractSubset(int maxRows) { List subset = new ArrayList(); for (int i = 0; i < columns.size(); i++) { - subset.add(columns.get(i).extractSubset(0, numRows)); + subset.add(columns.get(i).extractSubset(numRows)); } ColumnBasedSet result = new ColumnBasedSet(descriptors, subset, startOffset); startOffset += numRows;