diff --git hbase-spark/src/main/java/org/apache/hadoop/hbase/spark/SparkSQLPushDownFilter.java hbase-spark/src/main/java/org/apache/hadoop/hbase/spark/SparkSQLPushDownFilter.java index 057853f..249831e 100644 --- hbase-spark/src/main/java/org/apache/hadoop/hbase/spark/SparkSQLPushDownFilter.java +++ hbase-spark/src/main/java/org/apache/hadoop/hbase/spark/SparkSQLPushDownFilter.java @@ -24,6 +24,8 @@ import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hbase.Cell; import org.apache.hadoop.hbase.exceptions.DeserializationException; import org.apache.hadoop.hbase.filter.FilterBase; +import org.apache.hadoop.hbase.spark.datasources.BytesEncoder; +import org.apache.hadoop.hbase.spark.datasources.JavaBytesEncoder; import org.apache.hadoop.hbase.spark.protobuf.generated.FilterProtos; import org.apache.hadoop.hbase.util.ByteStringer; import org.apache.hadoop.hbase.util.Bytes; @@ -56,21 +58,25 @@ public class SparkSQLPushDownFilter extends FilterBase{ static final byte[] rowKeyFamily = new byte[0]; static final byte[] rowKeyQualifier = Bytes.toBytes("key"); + String encoder; + public SparkSQLPushDownFilter(DynamicLogicExpression dynamicLogicExpression, byte[][] valueFromQueryArray, HashMap> - currentCellToColumnIndexMap) { + currentCellToColumnIndexMap, String encoder) { this.dynamicLogicExpression = dynamicLogicExpression; this.valueFromQueryArray = valueFromQueryArray; this.currentCellToColumnIndexMap = currentCellToColumnIndexMap; + this.encoder = encoder; } public SparkSQLPushDownFilter(DynamicLogicExpression dynamicLogicExpression, byte[][] valueFromQueryArray, - MutableList fields) { + MutableList fields, String encoder) { this.dynamicLogicExpression = dynamicLogicExpression; this.valueFromQueryArray = valueFromQueryArray; + this.encoder = encoder; //generate family qualifier to index mapping this.currentCellToColumnIndexMap = @@ -184,9 +190,12 @@ public class SparkSQLPushDownFilter extends FilterBase{ throw new DeserializationException(e); } + String encoder = proto.getEncoder(); + BytesEncoder enc = JavaBytesEncoder.create(encoder); + //Load DynamicLogicExpression DynamicLogicExpression dynamicLogicExpression = - DynamicLogicExpressionBuilder.build(proto.getDynamicLogicExpression()); + DynamicLogicExpressionBuilder.build(proto.getDynamicLogicExpression(), enc); //Load valuesFromQuery final List valueFromQueryArrayList = proto.getValueFromQueryArrayList(); @@ -225,7 +234,7 @@ public class SparkSQLPushDownFilter extends FilterBase{ } return new SparkSQLPushDownFilter(dynamicLogicExpression, - valueFromQueryArray, currentCellToColumnIndexMap); + valueFromQueryArray, currentCellToColumnIndexMap, encoder); } /** @@ -256,6 +265,8 @@ public class SparkSQLPushDownFilter extends FilterBase{ builder.addCellToColumnMapping(columnMappingBuilder.build()); } } + builder.setEncoder(encoder); + return builder.build().toByteArray(); } diff --git hbase-spark/src/main/java/org/apache/hadoop/hbase/spark/protobuf/generated/FilterProtos.java hbase-spark/src/main/java/org/apache/hadoop/hbase/spark/protobuf/generated/FilterProtos.java index 1968d32..c909e90 100644 --- hbase-spark/src/main/java/org/apache/hadoop/hbase/spark/protobuf/generated/FilterProtos.java +++ hbase-spark/src/main/java/org/apache/hadoop/hbase/spark/protobuf/generated/FilterProtos.java @@ -783,6 +783,21 @@ public final class FilterProtos { */ org.apache.hadoop.hbase.spark.protobuf.generated.FilterProtos.SQLPredicatePushDownCellToColumnMappingOrBuilder getCellToColumnMappingOrBuilder( int index); + + // required string encoder = 
4; + /** + * required string encoder = 4; + */ + boolean hasEncoder(); + /** + * required string encoder = 4; + */ + java.lang.String getEncoder(); + /** + * required string encoder = 4; + */ + com.google.protobuf.ByteString + getEncoderBytes(); } /** * Protobuf type {@code hbase.pb.SQLPredicatePushDownFilter} @@ -856,6 +871,11 @@ public final class FilterProtos { cellToColumnMapping_.add(input.readMessage(org.apache.hadoop.hbase.spark.protobuf.generated.FilterProtos.SQLPredicatePushDownCellToColumnMapping.PARSER, extensionRegistry)); break; } + case 34: { + bitField0_ |= 0x00000002; + encoder_ = input.readBytes(); + break; + } } } } catch (com.google.protobuf.InvalidProtocolBufferException e) { @@ -1004,10 +1024,54 @@ public final class FilterProtos { return cellToColumnMapping_.get(index); } + // required string encoder = 4; + public static final int ENCODER_FIELD_NUMBER = 4; + private java.lang.Object encoder_; + /** + * required string encoder = 4; + */ + public boolean hasEncoder() { + return ((bitField0_ & 0x00000002) == 0x00000002); + } + /** + * required string encoder = 4; + */ + public java.lang.String getEncoder() { + java.lang.Object ref = encoder_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (bs.isValidUtf8()) { + encoder_ = s; + } + return s; + } + } + /** + * required string encoder = 4; + */ + public com.google.protobuf.ByteString + getEncoderBytes() { + java.lang.Object ref = encoder_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + encoder_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + private void initFields() { dynamicLogicExpression_ = ""; valueFromQueryArray_ = java.util.Collections.emptyList(); cellToColumnMapping_ = java.util.Collections.emptyList(); + encoder_ = ""; } private byte memoizedIsInitialized = -1; public final boolean isInitialized() { @@ -1018,6 +1082,10 @@ public final class FilterProtos { memoizedIsInitialized = 0; return false; } + if (!hasEncoder()) { + memoizedIsInitialized = 0; + return false; + } for (int i = 0; i < getCellToColumnMappingCount(); i++) { if (!getCellToColumnMapping(i).isInitialized()) { memoizedIsInitialized = 0; @@ -1040,6 +1108,9 @@ public final class FilterProtos { for (int i = 0; i < cellToColumnMapping_.size(); i++) { output.writeMessage(3, cellToColumnMapping_.get(i)); } + if (((bitField0_ & 0x00000002) == 0x00000002)) { + output.writeBytes(4, getEncoderBytes()); + } getUnknownFields().writeTo(output); } @@ -1066,6 +1137,10 @@ public final class FilterProtos { size += com.google.protobuf.CodedOutputStream .computeMessageSize(3, cellToColumnMapping_.get(i)); } + if (((bitField0_ & 0x00000002) == 0x00000002)) { + size += com.google.protobuf.CodedOutputStream + .computeBytesSize(4, getEncoderBytes()); + } size += getUnknownFields().getSerializedSize(); memoizedSerializedSize = size; return size; @@ -1098,6 +1173,11 @@ public final class FilterProtos { .equals(other.getValueFromQueryArrayList()); result = result && getCellToColumnMappingList() .equals(other.getCellToColumnMappingList()); + result = result && (hasEncoder() == other.hasEncoder()); + if (hasEncoder()) { + result = result && getEncoder() + .equals(other.getEncoder()); + } result = result && getUnknownFields().equals(other.getUnknownFields()); return result; @@ 
-1123,6 +1203,10 @@ public final class FilterProtos { hash = (37 * hash) + CELL_TO_COLUMN_MAPPING_FIELD_NUMBER; hash = (53 * hash) + getCellToColumnMappingList().hashCode(); } + if (hasEncoder()) { + hash = (37 * hash) + ENCODER_FIELD_NUMBER; + hash = (53 * hash) + getEncoder().hashCode(); + } hash = (29 * hash) + getUnknownFields().hashCode(); memoizedHashCode = hash; return hash; @@ -1243,6 +1327,8 @@ public final class FilterProtos { } else { cellToColumnMappingBuilder_.clear(); } + encoder_ = ""; + bitField0_ = (bitField0_ & ~0x00000008); return this; } @@ -1289,6 +1375,10 @@ public final class FilterProtos { } else { result.cellToColumnMapping_ = cellToColumnMappingBuilder_.build(); } + if (((from_bitField0_ & 0x00000008) == 0x00000008)) { + to_bitField0_ |= 0x00000002; + } + result.encoder_ = encoder_; result.bitField0_ = to_bitField0_; onBuilt(); return result; @@ -1346,6 +1436,11 @@ public final class FilterProtos { } } } + if (other.hasEncoder()) { + bitField0_ |= 0x00000008; + encoder_ = other.encoder_; + onChanged(); + } this.mergeUnknownFields(other.getUnknownFields()); return this; } @@ -1355,6 +1450,10 @@ public final class FilterProtos { return false; } + if (!hasEncoder()) { + + return false; + } for (int i = 0; i < getCellToColumnMappingCount(); i++) { if (!getCellToColumnMapping(i).isInitialized()) { @@ -1769,6 +1868,80 @@ public final class FilterProtos { return cellToColumnMappingBuilder_; } + // required string encoder = 4; + private java.lang.Object encoder_ = ""; + /** + * required string encoder = 4; + */ + public boolean hasEncoder() { + return ((bitField0_ & 0x00000008) == 0x00000008); + } + /** + * required string encoder = 4; + */ + public java.lang.String getEncoder() { + java.lang.Object ref = encoder_; + if (!(ref instanceof java.lang.String)) { + java.lang.String s = ((com.google.protobuf.ByteString) ref) + .toStringUtf8(); + encoder_ = s; + return s; + } else { + return (java.lang.String) ref; + } + } + /** + * required string encoder = 4; + */ + public com.google.protobuf.ByteString + getEncoderBytes() { + java.lang.Object ref = encoder_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + encoder_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + * required string encoder = 4; + */ + public Builder setEncoder( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000008; + encoder_ = value; + onChanged(); + return this; + } + /** + * required string encoder = 4; + */ + public Builder clearEncoder() { + bitField0_ = (bitField0_ & ~0x00000008); + encoder_ = getDefaultInstance().getEncoder(); + onChanged(); + return this; + } + /** + * required string encoder = 4; + */ + public Builder setEncoderBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000008; + encoder_ = value; + onChanged(); + return this; + } + // @@protoc_insertion_point(builder_scope:hbase.pb.SQLPredicatePushDownFilter) } @@ -1802,13 +1975,14 @@ public final class FilterProtos { "\n\014Filter.proto\022\010hbase.pb\"h\n\'SQLPredicate" + "PushDownCellToColumnMapping\022\025\n\rcolumn_fa" + "mily\030\001 \002(\014\022\021\n\tqualifier\030\002 \002(\014\022\023\n\013column_" + - "name\030\003 \002(\t\"\261\001\n\032SQLPredicatePushDownFilte" + + "name\030\003 \002(\t\"\302\001\n\032SQLPredicatePushDownFilte" + "r\022 
\n\030dynamic_logic_expression\030\001 \002(\t\022\036\n\026v" + "alue_from_query_array\030\002 \003(\014\022Q\n\026cell_to_c" + "olumn_mapping\030\003 \003(\01321.hbase.pb.SQLPredic" + - "atePushDownCellToColumnMappingBH\n0org.ap" + - "ache.hadoop.hbase.spark.protobuf.generat" + - "edB\014FilterProtosH\001\210\001\001\240\001\001" + "atePushDownCellToColumnMapping\022\017\n\007encode" + + "r\030\004 \002(\tBH\n0org.apache.hadoop.hbase.spark" + + ".protobuf.generatedB\014FilterProtosH\001\210\001\001\240\001", + "\001" }; com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = new com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner() { @@ -1826,7 +2000,7 @@ public final class FilterProtos { internal_static_hbase_pb_SQLPredicatePushDownFilter_fieldAccessorTable = new com.google.protobuf.GeneratedMessage.FieldAccessorTable( internal_static_hbase_pb_SQLPredicatePushDownFilter_descriptor, - new java.lang.String[] { "DynamicLogicExpression", "ValueFromQueryArray", "CellToColumnMapping", }); + new java.lang.String[] { "DynamicLogicExpression", "ValueFromQueryArray", "CellToColumnMapping", "Encoder", }); return null; } }; diff --git hbase-spark/src/main/protobuf/Filter.proto hbase-spark/src/main/protobuf/Filter.proto index e076ce8..47b95ff 100644 --- hbase-spark/src/main/protobuf/Filter.proto +++ hbase-spark/src/main/protobuf/Filter.proto @@ -35,4 +35,5 @@ message SQLPredicatePushDownFilter { required string dynamic_logic_expression = 1; repeated bytes value_from_query_array = 2; repeated SQLPredicatePushDownCellToColumnMapping cell_to_column_mapping = 3; + required string encoder = 4; } \ No newline at end of file diff --git hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala index 1697036..8b15cc7 100644 --- hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala +++ hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala @@ -23,9 +23,7 @@ import java.util.concurrent.ConcurrentLinkedQueue import org.apache.hadoop.hbase.client._ import org.apache.hadoop.hbase.io.ImmutableBytesWritable import org.apache.hadoop.hbase.mapred.TableOutputFormat -import org.apache.hadoop.hbase.spark.datasources.HBaseSparkConf -import org.apache.hadoop.hbase.spark.datasources.HBaseTableScanRDD -import org.apache.hadoop.hbase.spark.datasources.SerializableConfiguration +import org.apache.hadoop.hbase.spark.datasources._ import org.apache.hadoop.hbase.types._ import org.apache.hadoop.hbase.util.{Bytes, PositionedByteRange, SimplePositionedMutableByteRange} import org.apache.hadoop.hbase._ @@ -92,6 +90,9 @@ case class HBaseRelation ( val minTimestamp = parameters.get(HBaseSparkConf.MIN_TIMESTAMP).map(_.toLong) val maxTimestamp = parameters.get(HBaseSparkConf.MAX_TIMESTAMP).map(_.toLong) val maxVersions = parameters.get(HBaseSparkConf.MAX_VERSIONS).map(_.toInt) + val encoderClsName = parameters.get(HBaseSparkConf.ENCODER).getOrElse(HBaseSparkConf.defaultEncoder) + + @transient val encoder = JavaBytesEncoder.create(encoderClsName) val catalog = HBaseTableCatalog(parameters) def tableName = catalog.name @@ -335,7 +336,7 @@ case class HBaseRelation ( val pushDownFilterJava = if (usePushDownColumnFilter && pushDownDynamicLogicExpression != null) { Some(new SparkSQLPushDownFilter(pushDownDynamicLogicExpression, - valueArray, requiredQualifierDefinitionList)) + valueArray, requiredQualifierDefinitionList, encoderClsName)) } else { None } 
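Before the next hunk, a brief sketch (not part of the patch) of the BoundRange splitting it introduces: for an integer predicate such as "COLUMN < 2", BoundRange yields the two byte-array ranges that the row-key filter then unions. The REPL-style snippet and its printed output are illustrative only and simply restate the BoundRange.apply cases added later in this diff.

  // Illustrative only: what the new BoundRange helper returns for the value 2.
  import org.apache.hadoop.hbase.spark.datasources.BoundRange
  import org.apache.hadoop.hbase.util.Bytes

  val ranges = BoundRange(2).get   // matches the Integer case of BoundRange.apply
  ranges.less.foreach(r => println(Bytes.toInt(r.low) + " .. " + Bytes.toInt(r.upper)))
  // 0 .. 2
  // -2147483648 .. -1
  // i.e. "COLUMN < 2" becomes "0 <= COLUMN < 2 OR Integer.MIN_VALUE <= COLUMN <= -1"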
@@ -402,11 +403,17 @@ case class HBaseRelation ( (superRowKeyFilter, superDynamicLogicExpression, queryValueArray) } + /** + * Because we don't assume any specific encoder/decoder, and the ordering of some java primitive + * types is inconsistent with the ordering of their byte arrays, we have to split a predicate on such + * a type into multiple predicates and union them so that the predicate is evaluated correctly. + * For example, "COLUMN < 2" is transformed into + * "0 <= COLUMN < 2 OR Integer.MIN_VALUE <= COLUMN <= -1" + */ def transverseFilterTree(parentRowKeyFilter:RowKeyFilter, valueArray:mutable.MutableList[Array[Byte]], filter:Filter): DynamicLogicExpression = { filter match { - case EqualTo(attr, value) => val field = catalog.getField(attr) if (field != null) { @@ -420,18 +427,38 @@ case class HBaseRelation ( valueArray += byteValue } new EqualLogicExpression(attr, valueArray.length - 1, false) + + /** + * BoundRange splits the predicate into multiple byte array boundaries. + * Each boundary is mapped to a RowKeyFilter and then unioned by the reduce + * operation. If the data type is not supported by BoundRange, b is None and the + * parentRowKeyFilter is left untouched. + * + * Note that because LessThan is not inclusive, the first bound has to be exclusive, + * which is controlled by inc. + * + * The other predicates, i.e., GreaterThan/LessThanOrEqual/GreaterThanOrEqual, follow + * similar logic. + */ case LessThan(attr, value) => val field = catalog.getField(attr) if (field != null) { if (field.isRowKey) { - parentRowKeyFilter.mergeIntersect(new RowKeyFilter(null, - new ScanRange(DefaultSourceStaticUtils.getByteValue(field, - value.toString), false, - new Array[Byte](0), true))) + val b = BoundRange(value) + var inc = false + b.map(_.less.map { x => + val r = new RowKeyFilter(null, + new ScanRange(x.upper, inc, x.low, true) + ) + inc = true + r + }).map { x => + x.reduce { (i, j) => + i.mergeUnion(j) + } + }.map(parentRowKeyFilter.mergeIntersect(_)) } - val byteValue = - DefaultSourceStaticUtils.getByteValue(catalog.getField(attr), - value.toString) + val byteValue = encoder.encode(field.dt, value) valueArray += byteValue } new LessThanLogicExpression(attr, valueArray.length - 1) @@ -439,13 +466,20 @@ case class HBaseRelation ( val field = catalog.getField(attr) if (field != null) { if (field.isRowKey) { - parentRowKeyFilter.mergeIntersect(new RowKeyFilter(null, - new ScanRange(null, true, DefaultSourceStaticUtils.getByteValue(field, - value.toString), false))) + val b = BoundRange(value) + var inc = false + b.map(_.greater.map{x => + val r = new RowKeyFilter(null, + new ScanRange(x.upper, true, x.low, inc)) + inc = true + r + }).map { x => + x.reduce { (i, j) => + i.mergeUnion(j) + } + }.map(parentRowKeyFilter.mergeIntersect(_)) } - val byteValue = - DefaultSourceStaticUtils.getByteValue(field, - value.toString) + val byteValue = encoder.encode(field.dt, value) valueArray += byteValue } new GreaterThanLogicExpression(attr, valueArray.length - 1) @@ -453,14 +487,17 @@ case class HBaseRelation ( val field = catalog.getField(attr) if (field != null) { if (field.isRowKey) { - parentRowKeyFilter.mergeIntersect(new RowKeyFilter(null, - new ScanRange(DefaultSourceStaticUtils.getByteValue(field, - value.toString), true, - new Array[Byte](0), true))) + val b = BoundRange(value) + b.map(_.less.map(x => + new RowKeyFilter(null, + new ScanRange(x.upper, true, x.low, true)))) + .map { x => + x.reduce{ (i, j) 
=> + i.mergeUnion(j) + } + }.map(parentRowKeyFilter.mergeIntersect(_)) } - val byteValue = - DefaultSourceStaticUtils.getByteValue(catalog.getField(attr), - value.toString) + val byteValue = encoder.encode(field.dt, value) valueArray += byteValue } new LessThanOrEqualLogicExpression(attr, valueArray.length - 1) @@ -468,15 +505,18 @@ case class HBaseRelation ( val field = catalog.getField(attr) if (field != null) { if (field.isRowKey) { - parentRowKeyFilter.mergeIntersect(new RowKeyFilter(null, - new ScanRange(null, true, DefaultSourceStaticUtils.getByteValue(field, - value.toString), true))) + val b = BoundRange(value) + b.map(_.greater.map(x => + new RowKeyFilter(null, + new ScanRange(x.upper, true, x.low, true)))) + .map { x => + x.reduce { (i, j) => + i.mergeUnion(j) + } + }.map(parentRowKeyFilter.mergeIntersect(_)) } - val byteValue = - DefaultSourceStaticUtils.getByteValue(catalog.getField(attr), - value.toString) + val byteValue = encoder.encode(field.dt, value) valueArray += byteValue - } new GreaterThanOrEqualLogicExpression(attr, valueArray.length - 1) case Or(left, right) => @@ -587,7 +627,7 @@ class ScanRange(var upperBound:Array[Byte], var isUpperBoundEqualTo:Boolean, var leftRange:ScanRange = null var rightRange:ScanRange = null - //First identify the Left range + // First identify the Left range // Also lower bound can't be null if (compareRange(lowerBound, other.lowerBound) < 0 || compareRange(upperBound, other.upperBound) < 0) { @@ -598,17 +638,34 @@ class ScanRange(var upperBound:Array[Byte], var isUpperBoundEqualTo:Boolean, rightRange = this } - //Then see if leftRange goes to null or if leftRange.upperBound - // upper is greater or equals to rightRange.lowerBound - if (leftRange.upperBound == null || - Bytes.compareTo(leftRange.upperBound, rightRange.lowerBound) >= 0) { - new ScanRange(leftRange.upperBound, leftRange.isUpperBoundEqualTo, rightRange.lowerBound, rightRange.isLowerBoundEqualTo) + if (hasOverlap(leftRange, rightRange)) { + // Find the upper bound and lower bound + if (compareRange(leftRange.upperBound, rightRange.upperBound) >= 0) { + new ScanRange(rightRange.upperBound, rightRange.isUpperBoundEqualTo, + rightRange.lowerBound, rightRange.isLowerBoundEqualTo) + } else { + new ScanRange(leftRange.upperBound, leftRange.isUpperBoundEqualTo, + rightRange.lowerBound, rightRange.isLowerBoundEqualTo) + } } else { null } } /** + * The leftRange.upperBound has to be larger than the rightRange's lowBound. + * Otherwise, there is no overlap. + * + * @param left: The range with the smaller lowBound + * @param right: The range with the larger lowBound + * @return Whether two ranges has overlap. 
+ */ + + def hasOverlap(left: ScanRange, right: ScanRange): Boolean = { + compareRange(left.upperBound, right.lowerBound) >= 0 + } + + /** * Special compare logic because we can have null values * for left or right bound * @@ -1046,7 +1103,7 @@ class RowKeyFilter (currentPoint:Array[Byte] = null, * * @param other Filter to merge */ - def mergeUnion(other:RowKeyFilter): Unit = { + def mergeUnion(other:RowKeyFilter): RowKeyFilter = { other.points.foreach( p => points += p) other.ranges.foreach( otherR => { @@ -1058,6 +1115,7 @@ class RowKeyFilter (currentPoint:Array[Byte] = null, }} if (!doesOverLap) ranges.+=(otherR) }) + this } /** @@ -1066,7 +1124,7 @@ class RowKeyFilter (currentPoint:Array[Byte] = null, * * @param other Filter to merge */ - def mergeIntersect(other:RowKeyFilter): Unit = { + def mergeIntersect(other:RowKeyFilter): RowKeyFilter = { val survivingPoints = new mutable.MutableList[Array[Byte]]() val didntSurviveFirstPassPoints = new mutable.MutableList[Array[Byte]]() if (points == null || points.length == 0) { @@ -1112,6 +1170,7 @@ class RowKeyFilter (currentPoint:Array[Byte] = null, } points = survivingPoints ranges = survivingRanges + this } override def toString:String = { diff --git hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpression.scala hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpression.scala index fa61860..1a1d478 100644 --- hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpression.scala +++ hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpression.scala @@ -19,8 +19,11 @@ package org.apache.hadoop.hbase.spark import java.util +import org.apache.hadoop.hbase.spark.datasources.{BytesEncoder, JavaBytesEncoder} +import org.apache.hadoop.hbase.spark.datasources.JavaBytesEncoder.JavaBytesEncoder import org.apache.hadoop.hbase.util.Bytes - +import org.apache.spark.sql.datasources.hbase.{Field, Utils} +import org.apache.spark.sql.types._ /** * Dynamic logic for SQL push down logic there is an instance for most * common operations and a pass through for other operations not covered here @@ -38,7 +41,31 @@ trait DynamicLogicExpression { appendToExpression(strBuilder) strBuilder.toString() } + def filterOps: JavaBytesEncoder = JavaBytesEncoder.Unknown + def appendToExpression(strBuilder:StringBuilder) + + var encoder: BytesEncoder = _ + + def setEncoder(enc: BytesEncoder): DynamicLogicExpression = { + encoder = enc + this + } +} + +trait CompareTrait { + self: DynamicLogicExpression => + def columnName: String + def valueFromQueryIndex: Int + def execute(columnToCurrentRowValueMap: + util.HashMap[String, ByteArrayComparable], + valueFromQueryValueArray:Array[Array[Byte]]): Boolean = { + val currentRowValue = columnToCurrentRowValueMap.get(columnName) + val valueFromQuery = valueFromQueryValueArray(valueFromQueryIndex) + currentRowValue != null && + encoder.filter(currentRowValue.bytes, currentRowValue.offset, currentRowValue.length, + valueFromQuery, 0, valueFromQuery.length, filterOps) + } } class AndLogicExpression (val leftExpression:DynamicLogicExpression, @@ -113,59 +140,28 @@ class IsNullLogicExpression (val columnName:String, } } -class GreaterThanLogicExpression (val columnName:String, - val valueFromQueryIndex:Int) - extends DynamicLogicExpression{ - override def execute(columnToCurrentRowValueMap: - util.HashMap[String, ByteArrayComparable], - valueFromQueryValueArray:Array[Array[Byte]]): Boolean = { - val currentRowValue = columnToCurrentRowValueMap.get(columnName) - 
val valueFromQuery = valueFromQueryValueArray(valueFromQueryIndex) - - currentRowValue != null && - Bytes.compareTo(currentRowValue.bytes, - currentRowValue.offset, currentRowValue.length, valueFromQuery, - 0, valueFromQuery.length) > 0 - } +class GreaterThanLogicExpression (override val columnName:String, + override val valueFromQueryIndex:Int) + extends DynamicLogicExpression with CompareTrait{ + override val filterOps = JavaBytesEncoder.Greater override def appendToExpression(strBuilder: StringBuilder): Unit = { strBuilder.append(columnName + " > " + valueFromQueryIndex) } } -class GreaterThanOrEqualLogicExpression (val columnName:String, - val valueFromQueryIndex:Int) - extends DynamicLogicExpression{ - override def execute(columnToCurrentRowValueMap: - util.HashMap[String, ByteArrayComparable], - valueFromQueryValueArray:Array[Array[Byte]]): Boolean = { - val currentRowValue = columnToCurrentRowValueMap.get(columnName) - val valueFromQuery = valueFromQueryValueArray(valueFromQueryIndex) - - currentRowValue != null && - Bytes.compareTo(currentRowValue.bytes, - currentRowValue.offset, currentRowValue.length, valueFromQuery, - 0, valueFromQuery.length) >= 0 - } +class GreaterThanOrEqualLogicExpression (override val columnName:String, + override val valueFromQueryIndex:Int) + extends DynamicLogicExpression with CompareTrait{ + override val filterOps = JavaBytesEncoder.GreaterEqual override def appendToExpression(strBuilder: StringBuilder): Unit = { strBuilder.append(columnName + " >= " + valueFromQueryIndex) } } -class LessThanLogicExpression (val columnName:String, - val valueFromQueryIndex:Int) - extends DynamicLogicExpression{ - override def execute(columnToCurrentRowValueMap: - util.HashMap[String, ByteArrayComparable], - valueFromQueryValueArray:Array[Array[Byte]]): Boolean = { - val currentRowValue = columnToCurrentRowValueMap.get(columnName) - val valueFromQuery = valueFromQueryValueArray(valueFromQueryIndex) - - currentRowValue != null && - Bytes.compareTo(currentRowValue.bytes, - currentRowValue.offset, currentRowValue.length, valueFromQuery, - 0, valueFromQuery.length) < 0 - } - +class LessThanLogicExpression (override val columnName:String, + override val valueFromQueryIndex:Int) + extends DynamicLogicExpression with CompareTrait { + override val filterOps = JavaBytesEncoder.Less override def appendToExpression(strBuilder: StringBuilder): Unit = { strBuilder.append(columnName + " < " + valueFromQueryIndex) } @@ -173,19 +169,8 @@ class LessThanLogicExpression (val columnName:String, class LessThanOrEqualLogicExpression (val columnName:String, val valueFromQueryIndex:Int) - extends DynamicLogicExpression{ - override def execute(columnToCurrentRowValueMap: - util.HashMap[String, ByteArrayComparable], - valueFromQueryValueArray:Array[Array[Byte]]): Boolean = { - val currentRowValue = columnToCurrentRowValueMap.get(columnName) - val valueFromQuery = valueFromQueryValueArray(valueFromQueryIndex) - - currentRowValue != null && - Bytes.compareTo(currentRowValue.bytes, - currentRowValue.offset, currentRowValue.length, valueFromQuery, - 0, valueFromQuery.length) <= 0 - } - + extends DynamicLogicExpression with CompareTrait{ + override val filterOps = JavaBytesEncoder.LessEqual override def appendToExpression(strBuilder: StringBuilder): Unit = { strBuilder.append(columnName + " <= " + valueFromQueryIndex) } @@ -197,58 +182,66 @@ class PassThroughLogicExpression() extends DynamicLogicExpression { valueFromQueryValueArray: Array[Array[Byte]]): Boolean = true override def 
appendToExpression(strBuilder: StringBuilder): Unit = { - strBuilder.append("Pass") + // Fix the offset bug by add dummy to avoid crash the region server. + // because in the DynamicLogicExpressionBuilder.build function, the command is always retrieved from offset + 1 as below + // val command = expressionArray(offSet + 1) + // we have to padding it so that `Pass` is on the right offset. + strBuilder.append("dummy Pass -1") } } object DynamicLogicExpressionBuilder { - def build(expressionString:String): DynamicLogicExpression = { + def build(expressionString: String, encoder: BytesEncoder): DynamicLogicExpression = { - val expressionAndOffset = build(expressionString.split(' '), 0) + val expressionAndOffset = build(expressionString.split(' '), 0, encoder) expressionAndOffset._1 } private def build(expressionArray:Array[String], - offSet:Int): (DynamicLogicExpression, Int) = { - if (expressionArray(offSet).equals("(")) { - val left = build(expressionArray, offSet + 1) - val right = build(expressionArray, left._2 + 1) - if (expressionArray(left._2).equals("AND")) { - (new AndLogicExpression(left._1, right._1), right._2 + 1) - } else if (expressionArray(left._2).equals("OR")) { - (new OrLogicExpression(left._1, right._1), right._2 + 1) - } else { - throw new Throwable("Unknown gate:" + expressionArray(left._2)) - } - } else { - val command = expressionArray(offSet + 1) - if (command.equals("<")) { - (new LessThanLogicExpression(expressionArray(offSet), - expressionArray(offSet + 2).toInt), offSet + 3) - } else if (command.equals("<=")) { - (new LessThanOrEqualLogicExpression(expressionArray(offSet), - expressionArray(offSet + 2).toInt), offSet + 3) - } else if (command.equals(">")) { - (new GreaterThanLogicExpression(expressionArray(offSet), - expressionArray(offSet + 2).toInt), offSet + 3) - } else if (command.equals(">=")) { - (new GreaterThanOrEqualLogicExpression(expressionArray(offSet), - expressionArray(offSet + 2).toInt), offSet + 3) - } else if (command.equals("==")) { - (new EqualLogicExpression(expressionArray(offSet), - expressionArray(offSet + 2).toInt, false), offSet + 3) - } else if (command.equals("!=")) { - (new EqualLogicExpression(expressionArray(offSet), - expressionArray(offSet + 2).toInt, true), offSet + 3) - } else if (command.equals("isNull")) { - (new IsNullLogicExpression(expressionArray(offSet), false), offSet + 2) - } else if (command.equals("isNotNull")) { - (new IsNullLogicExpression(expressionArray(offSet), true), offSet + 2) - } else if (command.equals("Pass")) { - (new PassThroughLogicExpression, offSet + 2) + offSet:Int, encoder: BytesEncoder): (DynamicLogicExpression, Int) = { + val expr = { + if (expressionArray(offSet).equals("(")) { + val left = build(expressionArray, offSet + 1, encoder) + val right = build(expressionArray, left._2 + 1, encoder) + if (expressionArray(left._2).equals("AND")) { + (new AndLogicExpression(left._1, right._1), right._2 + 1) + } else if (expressionArray(left._2).equals("OR")) { + (new OrLogicExpression(left._1, right._1), right._2 + 1) + } else { + throw new Throwable("Unknown gate:" + expressionArray(left._2)) + } } else { - throw new Throwable("Unknown logic command:" + command) + val command = expressionArray(offSet + 1) + if (command.equals("<")) { + (new LessThanLogicExpression(expressionArray(offSet), + expressionArray(offSet + 2).toInt), offSet + 3) + } else if (command.equals("<=")) { + (new LessThanOrEqualLogicExpression(expressionArray(offSet), + expressionArray(offSet + 2).toInt), offSet + 3) + } else if 
(command.equals(">")) { + (new GreaterThanLogicExpression(expressionArray(offSet), + expressionArray(offSet + 2).toInt), offSet + 3) + } else if (command.equals(">=")) { + (new GreaterThanOrEqualLogicExpression(expressionArray(offSet), + expressionArray(offSet + 2).toInt), offSet + 3) + } else if (command.equals("==")) { + (new EqualLogicExpression(expressionArray(offSet), + expressionArray(offSet + 2).toInt, false), offSet + 3) + } else if (command.equals("!=")) { + (new EqualLogicExpression(expressionArray(offSet), + expressionArray(offSet + 2).toInt, true), offSet + 3) + } else if (command.equals("isNull")) { + (new IsNullLogicExpression(expressionArray(offSet), false), offSet + 2) + } else if (command.equals("isNotNull")) { + (new IsNullLogicExpression(expressionArray(offSet), true), offSet + 2) + } else if (command.equals("Pass")) { + (new PassThroughLogicExpression, offSet + 3) + } else { + throw new Throwable("Unknown logic command:" + command) + } } } + expr._1.setEncoder(encoder) + expr } } \ No newline at end of file diff --git hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/BoundRange.scala hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/BoundRange.scala new file mode 100644 index 0000000..b6b8e03 --- /dev/null +++ hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/BoundRange.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hbase.spark.datasources + +import org.apache.hadoop.hbase.util.Bytes +import org.apache.hadoop.hbase.spark.hbase._ +import org.apache.spark.Logging +import org.apache.spark.unsafe.types.UTF8String + +/** + * The ranges for the data type whose size is known. Whether the bound is inclusive + * or exclusive is undefind, and upper to the caller to decide. + * + * @param low: the lower bound of the range. + * @param upper: the upper bound of the range. + */ +case class BoundRange(low: Array[Byte],upper: Array[Byte]) + +/** + * The class identifies the ranges for a java primitive type. The caller needs + * to decide the bound is either inclusive or exclusive on its own. + * information + * + * @param less: the set of ranges for LessThan/LessOrEqualThan + * @param greater: the set of ranges for GreaterThan/GreaterThanOrEqualTo + * @param value: the byte array of the original value + */ +case class BoundRanges(less: Array[BoundRange], greater: Array[BoundRange], value: Array[Byte]) + +/** + * Evaluate the java primitive type and return the BoundRanges. For one value, it may have + * multiple output ranges because of the inconsistency of order between java primitive type + * and its byte array order. 
+ * + * For short, integer, and long, the order of number is consistent with byte array order + * if two number has the same sign bit. But the negative number is larger than positive + * number in byte array. + * + * For double and float, the order of positive number is consistent with its byte array order. + * But the order of negative number is the reverse order of byte array. Please refer to IEEE-754 + * and https://en.wikipedia.org/wiki/Single-precision_floating-point_format + */ +object BoundRange extends Logging{ + def apply(in: Any): Option[BoundRanges] = in match { + case a: Integer => + val b = Bytes.toBytes(a) + if (a >= 0) { + logDebug(s"range is 0 to $a and ${Integer.MIN_VALUE} to -1") + Some(BoundRanges( + Array(BoundRange(Bytes.toBytes(0: Int), b), + BoundRange(Bytes.toBytes(Integer.MIN_VALUE), Bytes.toBytes(-1: Int))), + Array(BoundRange(b, Bytes.toBytes(Integer.MAX_VALUE))), b)) + } else { + Some(BoundRanges( + Array(BoundRange(Bytes.toBytes(Integer.MIN_VALUE), b)), + Array(BoundRange(b, Bytes.toBytes(-1: Integer)), + BoundRange(Bytes.toBytes(0: Int), Bytes.toBytes(Integer.MAX_VALUE))), b)) + } + case a: Long => + val b = Bytes.toBytes(a) + if (a >= 0) { + Some(BoundRanges( + Array(BoundRange(Bytes.toBytes(0: Long), b), + BoundRange(Bytes.toBytes(Long.MinValue), Bytes.toBytes(-1: Long))), + Array(BoundRange(b, Bytes.toBytes(Long.MaxValue))), b)) + } else { + Some(BoundRanges( + Array(BoundRange(Bytes.toBytes(Long.MinValue), b)), + Array(BoundRange(b, Bytes.toBytes(-1: Long)), + BoundRange(Bytes.toBytes(0: Long), Bytes.toBytes(Long.MaxValue))), b)) + } + case a: Short => + val b = Bytes.toBytes(a) + if (a >= 0) { + Some(BoundRanges( + Array(BoundRange(Bytes.toBytes(0: Short), b), + BoundRange(Bytes.toBytes(Short.MinValue), Bytes.toBytes(-1: Short))), + Array(BoundRange(b, Bytes.toBytes(Short.MaxValue))), b)) + } else { + Some(BoundRanges( + Array(BoundRange(Bytes.toBytes(Short.MinValue), b)), + Array(BoundRange(b, Bytes.toBytes(-1: Short)), + BoundRange(Bytes.toBytes(0: Short), Bytes.toBytes(Short.MaxValue))), b)) + } + case a: Double => + val b = Bytes.toBytes(a) + if (a >= 0.0f) { + Some(BoundRanges( + Array(BoundRange(Bytes.toBytes(0.0d), b), + BoundRange(Bytes.toBytes(-0.0d), Bytes.toBytes(Double.MinValue))), + Array(BoundRange(b, Bytes.toBytes(Double.MaxValue))), b)) + } else { + Some(BoundRanges( + Array(BoundRange(b, Bytes.toBytes(Double.MinValue))), + Array(BoundRange(Bytes.toBytes(-0.0d), b), + BoundRange(Bytes.toBytes(0.0d), Bytes.toBytes(Double.MaxValue))), b)) + } + case a: Float => + val b = Bytes.toBytes(a) + if (a >= 0.0f) { + Some(BoundRanges( + Array(BoundRange(Bytes.toBytes(0.0f), b), + BoundRange(Bytes.toBytes(-0.0f), Bytes.toBytes(Float.MinValue))), + Array(BoundRange(b, Bytes.toBytes(Float.MaxValue))), b)) + } else { + Some(BoundRanges( + Array(BoundRange(b, Bytes.toBytes(Float.MinValue))), + Array(BoundRange(Bytes.toBytes(-0.0f), b), + BoundRange(Bytes.toBytes(0.0f), Bytes.toBytes(Float.MaxValue))), b)) + } + case a: Array[Byte] => + Some(BoundRanges( + Array(BoundRange(bytesMin, a)), + Array(BoundRange(a, bytesMax)), a)) + case a: Byte => + val b = Array(a) + Some(BoundRanges( + Array(BoundRange(bytesMin, b)), + Array(BoundRange(b, bytesMax)), b)) + case a: String => + val b = Bytes.toBytes(a) + Some(BoundRanges( + Array(BoundRange(bytesMin, b)), + Array(BoundRange(b, bytesMax)), b)) + case a: UTF8String => + val b = a.getBytes + Some(BoundRanges( + Array(BoundRange(bytesMin, b)), + Array(BoundRange(b, bytesMax)), b)) + case _ => None + } +} diff --git 
hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseSparkConf.scala hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseSparkConf.scala index be2af30..d9ee494 100644 --- hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseSparkConf.scala +++ hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseSparkConf.scala @@ -41,4 +41,6 @@ object HBaseSparkConf{ val MIN_TIMESTAMP = "hbase.spark.query.minTimestamp" val MAX_TIMESTAMP = "hbase.spark.query.maxTimestamp" val MAX_VERSIONS = "hbase.spark.query.maxVersions" + val ENCODER = "hbase.spark.query.encoder" + val defaultEncoder = classOf[NaiveEncoder].getCanonicalName } diff --git hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/JavaBytesEncoder.scala hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/JavaBytesEncoder.scala new file mode 100644 index 0000000..90b2d91 --- /dev/null +++ hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/JavaBytesEncoder.scala @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hbase.spark.datasources + +import org.apache.hadoop.hbase.spark.datasources.JavaBytesEncoder.JavaBytesEncoder +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.sql.types._ + +/** + * The trait to support plugin architecture for different encoder/decoder. + * encoder is used for serialize the data type to byte arrway and the filter is + * used to filter out the unnecessary records. + */ +trait BytesEncoder { + def encode(dt: DataType, value: Any): Array[Byte] + + /** + * The function performing real filtering operations. The format of filterBytes depends on the + * implementation of the encoder/decoder. + * + * @param input: the current input byte array that needs to be filtered out + * @param offset1: the starting offset of the input byte array. + * @param length1: the length of the input byte array. + * @param filterBytes: the byte array provided by query condition. + * @param offset2: the starting offset in the filterBytes. + * @param length2: the length of the bytes in the filterBytes + * @param ops: The operation of the filter operator. + * @return true: the record satisfies the predicates + * false: the record does not satisfies the predicates. + */ + def filter(input: Array[Byte], offset1: Int, length1: Int, + filterBytes: Array[Byte], offset2: Int, length2: Int, + ops: JavaBytesEncoder): Boolean +} + +/** + * This is the naive non-order preserving encoder/decoder. + * Due to the inconsistency of the order between java primitive types + * and their bytearray. 
The data type has to be passed in so that the filter + * can work correctly, which is done by wrapping the type into the first byte + * of the serialized array. + */ +class NaiveEncoder extends BytesEncoder { + var code = 0 + def nextCode: Byte = { + code += 1 + (code - 1).asInstanceOf[Byte] + } + val BooleanEnc = nextCode + val ShortEnc = nextCode + val IntEnc = nextCode + val LongEnc = nextCode + val FloatEnc = nextCode + val DoubleEnc = nextCode + val StringEnc = nextCode + val BinaryEnc = nextCode + val TimestampEnc = nextCode + val UnknownEnc = nextCode + + def compare(c: Int, ops: JavaBytesEncoder): Boolean = { + ops match { + case JavaBytesEncoder.Greater => c > 0 + case JavaBytesEncoder.GreaterEqual => c >= 0 + case JavaBytesEncoder.Less => c < 0 + case JavaBytesEncoder.LessEqual => c <= 0 + } + } + + /** + * Encode the value of the given data type into a byte array. Note that this is a naive + * implementation, with the data type byte prepended to the head of the serialized byte array. + * + * @param dt: The data type of the input + * @param value: the value of the input + * @return the byte array with the first byte indicating the data type. + */ + override def encode(dt: DataType, + value: Any): Array[Byte] = { + dt match { + case BooleanType => + val result = new Array[Byte](Bytes.SIZEOF_BOOLEAN + 1) + result(0) = BooleanEnc + value.asInstanceOf[Boolean] match { + case true => result(1) = -1: Byte + case false => result(1) = 0: Byte + } + result + case ShortType => + val result = new Array[Byte](Bytes.SIZEOF_SHORT + 1) + result(0) = ShortEnc + Bytes.putShort(result, 1, value.asInstanceOf[Short]) + result + case IntegerType => + val result = new Array[Byte](Bytes.SIZEOF_INT + 1) + result(0) = IntEnc + Bytes.putInt(result, 1, value.asInstanceOf[Int]) + result + case LongType|TimestampType => + val result = new Array[Byte](Bytes.SIZEOF_LONG + 1) + result(0) = LongEnc + Bytes.putLong(result, 1, value.asInstanceOf[Long]) + result + case FloatType => + val result = new Array[Byte](Bytes.SIZEOF_FLOAT + 1) + result(0) = FloatEnc + Bytes.putFloat(result, 1, value.asInstanceOf[Float]) + result + case DoubleType => + val result = new Array[Byte](Bytes.SIZEOF_DOUBLE + 1) + result(0) = DoubleEnc + Bytes.putDouble(result, 1, value.asInstanceOf[Double]) + result + case BinaryType => + val v = value.asInstanceOf[Array[Byte]] + val result = new Array[Byte](v.length + 1) + result(0) = BinaryEnc + System.arraycopy(v, 0, result, 1, v.length) + result + case StringType => + val bytes = Bytes.toBytes(value.asInstanceOf[String]) + val result = new Array[Byte](bytes.length + 1) + result(0) = StringEnc + System.arraycopy(bytes, 0, result, 1, bytes.length) + result + case _ => + val bytes = Bytes.toBytes(value.toString) + val result = new Array[Byte](bytes.length + 1) + result(0) = UnknownEnc + System.arraycopy(bytes, 0, result, 1, bytes.length) + result + } + } + + override def filter(input: Array[Byte], offset1: Int, length1: Int, + filterBytes: Array[Byte], offset2: Int, length2: Int, + ops: JavaBytesEncoder): Boolean = { + filterBytes(offset2) match { + case ShortEnc => + val in = Bytes.toShort(input, offset1) + val value = Bytes.toShort(filterBytes, offset2 + 1) + compare(in.compareTo(value), ops) + case IntEnc => + val in = Bytes.toInt(input, offset1) + val value = Bytes.toInt(filterBytes, offset2 + 1) + compare(in.compareTo(value), ops) + case LongEnc | TimestampEnc => + val in = Bytes.toLong(input, offset1) + val value = Bytes.toLong(filterBytes, offset2 + 1) + compare(in.compareTo(value), ops) + case FloatEnc => + val 
in = Bytes.toFloat(input, offset1) + val value = Bytes.toFloat(filterBytes, offset2 + 1) + compare(in.compareTo(value), ops) + case DoubleEnc => + val in = Bytes.toDouble(input, offset1) + val value = Bytes.toDouble(filterBytes, offset2 + 1) + compare(in.compareTo(value), ops) + case _ => + // for String, Byte, Binary, Boolean and other types + // we can use the order of byte array directly. + compare( + Bytes.compareTo(input, offset1, length1, filterBytes, offset2 + 1, length2 - 1), ops) + } + } +} + +object JavaBytesEncoder extends Enumeration { + type JavaBytesEncoder = Value + val Greater, GreaterEqual, Less, LessEqual, Equal, Unknown = Value + + /** + * create the encoder/decoder + * + * @param clsName: the class name of the encoder/decoder class + * @return the instance of the encoder plugin. + */ + def create(clsName: String): BytesEncoder = { + try { + Class.forName(clsName).newInstance.asInstanceOf[BytesEncoder] + } catch { + case _: Throwable => new NaiveEncoder() + } + } +} \ No newline at end of file diff --git hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/package.scala hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/package.scala index 4ff0413..ce7b55a 100644 --- hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/package.scala +++ hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/package.scala @@ -23,6 +23,8 @@ import scala.math.Ordering package object hbase { type HBaseType = Array[Byte] + def bytesMin = new Array[Byte](0) + def bytesMax = null val ByteMax = -1.asInstanceOf[Byte] val ByteMin = 0.asInstanceOf[Byte] val ord: Ordering[HBaseType] = new Ordering[HBaseType] { diff --git hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/HBaseTableCatalog.scala hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/HBaseTableCatalog.scala index 831c7de..c2d611f 100644 --- hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/HBaseTableCatalog.scala +++ hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/HBaseTableCatalog.scala @@ -156,7 +156,7 @@ case class HBaseTableCatalog( def get(key: String) = params.get(key) // Setup the start and length for each dimension of row key at runtime. 
- def dynSetupRowKey(rowKey: HBaseType) { + def dynSetupRowKey(rowKey: Array[Byte]) { logDebug(s"length: ${rowKey.length}") if(row.varLength) { var start = 0 diff --git hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala index 4312b38..0f8baed 100644 --- hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala +++ hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala @@ -329,14 +329,7 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging { equals("( KEY_FIELD < 0 OR KEY_FIELD > 1 )")) assert(executionRules.rowKeyFilter.points.size == 0) - assert(executionRules.rowKeyFilter.ranges.size == 1) - - val scanRange1 = executionRules.rowKeyFilter.ranges.get(0).get - assert(Bytes.equals(scanRange1.lowerBound,Bytes.toBytes(""))) - assert(scanRange1.upperBound == null) - assert(scanRange1.isLowerBoundEqualTo) - assert(scanRange1.isUpperBoundEqualTo) - + assert(executionRules.rowKeyFilter.ranges.size == 2) assert(results.length == 5) } @@ -358,18 +351,14 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging { assert(executionRules.rowKeyFilter.points.size == 0) - assert(executionRules.rowKeyFilter.ranges.size == 2) + assert(executionRules.rowKeyFilter.ranges.size == 3) val scanRange1 = executionRules.rowKeyFilter.ranges.get(0).get - assert(Bytes.equals(scanRange1.lowerBound,Bytes.toBytes(""))) assert(Bytes.equals(scanRange1.upperBound, Bytes.toBytes(2))) assert(scanRange1.isLowerBoundEqualTo) assert(!scanRange1.isUpperBoundEqualTo) val scanRange2 = executionRules.rowKeyFilter.ranges.get(1).get - assert(Bytes.equals(scanRange2.lowerBound, Bytes.toBytes(4))) - assert(scanRange2.upperBound == null) - assert(!scanRange2.isLowerBoundEqualTo) assert(scanRange2.isUpperBoundEqualTo) assert(results.length == 2) diff --git hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpressionSuite.scala hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpressionSuite.scala index 3140ebd..ff4201c 100644 --- hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpressionSuite.scala +++ hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpressionSuite.scala @@ -19,13 +19,17 @@ package org.apache.hadoop.hbase.spark import java.util +import org.apache.hadoop.hbase.spark.datasources.{HBaseSparkConf, JavaBytesEncoder} import org.apache.hadoop.hbase.util.Bytes import org.apache.spark.Logging +import org.apache.spark.sql.types._ import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} class DynamicLogicExpressionSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll with Logging { + val encoder = JavaBytesEncoder.create(HBaseSparkConf.defaultEncoder) + test("Basic And Test") { val leftLogic = new LessThanLogicExpression("Col1", 0) val rightLogic = new GreaterThanLogicExpression("Col1", 1) @@ -35,33 +39,33 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging { columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes(10))) val valueFromQueryValueArray = new Array[Array[Byte]](2) - valueFromQueryValueArray(0) = Bytes.toBytes(15) - valueFromQueryValueArray(1) = Bytes.toBytes(5) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5) assert(andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) - valueFromQueryValueArray(0) = Bytes.toBytes(10) - 
valueFromQueryValueArray(1) = Bytes.toBytes(5) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5) assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) - valueFromQueryValueArray(0) = Bytes.toBytes(15) - valueFromQueryValueArray(1) = Bytes.toBytes(10) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 10) assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) val expressionString = andLogic.toExpressionString assert(expressionString.equals("( Col1 < 0 AND Col1 > 1 )")) - val builtExpression = DynamicLogicExpressionBuilder.build(expressionString) - valueFromQueryValueArray(0) = Bytes.toBytes(15) - valueFromQueryValueArray(1) = Bytes.toBytes(5) + val builtExpression = DynamicLogicExpressionBuilder.build(expressionString, encoder) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5) assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) - valueFromQueryValueArray(0) = Bytes.toBytes(10) - valueFromQueryValueArray(1) = Bytes.toBytes(5) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5) assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) - valueFromQueryValueArray(0) = Bytes.toBytes(15) - valueFromQueryValueArray(1) = Bytes.toBytes(10) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 10) assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) } @@ -75,41 +79,41 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging { columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes(10))) val valueFromQueryValueArray = new Array[Array[Byte]](2) - valueFromQueryValueArray(0) = Bytes.toBytes(15) - valueFromQueryValueArray(1) = Bytes.toBytes(5) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5) assert(OrLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) - valueFromQueryValueArray(0) = Bytes.toBytes(10) - valueFromQueryValueArray(1) = Bytes.toBytes(5) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5) assert(OrLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) - valueFromQueryValueArray(0) = Bytes.toBytes(15) - valueFromQueryValueArray(1) = Bytes.toBytes(10) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 10) assert(OrLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) - valueFromQueryValueArray(0) = Bytes.toBytes(10) - valueFromQueryValueArray(1) = Bytes.toBytes(10) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 10) assert(!OrLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) val expressionString = OrLogic.toExpressionString assert(expressionString.equals("( Col1 < 0 OR Col1 > 1 )")) - val builtExpression = DynamicLogicExpressionBuilder.build(expressionString) - valueFromQueryValueArray(0) = Bytes.toBytes(15) - valueFromQueryValueArray(1) = Bytes.toBytes(5) + val builtExpression = 
DynamicLogicExpressionBuilder.build(expressionString, encoder) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5) assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) - valueFromQueryValueArray(0) = Bytes.toBytes(10) - valueFromQueryValueArray(1) = Bytes.toBytes(5) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 5) assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) - valueFromQueryValueArray(0) = Bytes.toBytes(15) - valueFromQueryValueArray(1) = Bytes.toBytes(10) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 15) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 10) assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) - valueFromQueryValueArray(0) = Bytes.toBytes(10) - valueFromQueryValueArray(1) = Bytes.toBytes(10) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10) + valueFromQueryValueArray(1) = encoder.encode(IntegerType, 10) assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) } @@ -127,40 +131,40 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging { val valueFromQueryValueArray = new Array[Array[Byte]](1) //great than - valueFromQueryValueArray(0) = Bytes.toBytes(10) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10) assert(!greaterLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) - valueFromQueryValueArray(0) = Bytes.toBytes(20) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 20) assert(!greaterLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) //great than and equal - valueFromQueryValueArray(0) = Bytes.toBytes(5) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 5) assert(greaterAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) - valueFromQueryValueArray(0) = Bytes.toBytes(10) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10) assert(greaterAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) - valueFromQueryValueArray(0) = Bytes.toBytes(20) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 20) assert(!greaterAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) //less than - valueFromQueryValueArray(0) = Bytes.toBytes(10) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10) assert(!lessLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) - valueFromQueryValueArray(0) = Bytes.toBytes(5) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 5) assert(!lessLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) //less than and equal - valueFromQueryValueArray(0) = Bytes.toBytes(20) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 20) assert(lessAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) - valueFromQueryValueArray(0) = Bytes.toBytes(20) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 20) assert(lessAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) - valueFromQueryValueArray(0) = Bytes.toBytes(10) + valueFromQueryValueArray(0) = encoder.encode(IntegerType, 10) assert(lessAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) //equal too @@ -183,8 +187,137 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging { valueFromQueryValueArray(0) = 
Bytes.toBytes(5) assert(passThrough.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + } + + + test("Double Type") { + val leftLogic = new LessThanLogicExpression("Col1", 0) + val rightLogic = new GreaterThanLogicExpression("Col1", 1) + val andLogic = new AndLogicExpression(leftLogic, rightLogic) + + val columnToCurrentRowValueMap = new util.HashMap[String, ByteArrayComparable]() + + columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes(-4.0d))) + val valueFromQueryValueArray = new Array[Array[Byte]](2) + valueFromQueryValueArray(0) = encoder.encode(DoubleType, 15.0d) + valueFromQueryValueArray(1) = encoder.encode(DoubleType, -5.0d) + assert(andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(DoubleType, 10.0d) + valueFromQueryValueArray(1) = encoder.encode(DoubleType, -1.0d) + assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(DoubleType, -10.0d) + valueFromQueryValueArray(1) = encoder.encode(DoubleType, -20.0d) + assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + val expressionString = andLogic.toExpressionString + // Note that here 0 and 1 is index, instead of value. + assert(expressionString.equals("( Col1 < 0 AND Col1 > 1 )")) + + val builtExpression = DynamicLogicExpressionBuilder.build(expressionString, encoder) + valueFromQueryValueArray(0) = encoder.encode(DoubleType, 15.0d) + valueFromQueryValueArray(1) = encoder.encode(DoubleType, -5.0d) + assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(DoubleType, 10.0d) + valueFromQueryValueArray(1) = encoder.encode(DoubleType, -1.0d) + assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(DoubleType, -10.0d) + valueFromQueryValueArray(1) = encoder.encode(DoubleType, -20.0d) + assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + } + + test("Float Type") { + val leftLogic = new LessThanLogicExpression("Col1", 0) + val rightLogic = new GreaterThanLogicExpression("Col1", 1) + val andLogic = new AndLogicExpression(leftLogic, rightLogic) + + val columnToCurrentRowValueMap = new util.HashMap[String, ByteArrayComparable]() + columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes(-4.0f))) + val valueFromQueryValueArray = new Array[Array[Byte]](2) + valueFromQueryValueArray(0) = encoder.encode(FloatType, 15.0f) + valueFromQueryValueArray(1) = encoder.encode(FloatType, -5.0f) + assert(andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(FloatType, 10.0f) + valueFromQueryValueArray(1) = encoder.encode(FloatType, -1.0f) + assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(FloatType, -10.0f) + valueFromQueryValueArray(1) = encoder.encode(FloatType, -20.0f) + assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + val expressionString = andLogic.toExpressionString + // Note that here 0 and 1 is index, instead of value. 
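+    // i.e. the built expression evaluates as: Col1 < valueFromQueryValueArray(0) AND
+    // Col1 > valueFromQueryValueArray(1); the indices are resolved at execute() time.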
+ assert(expressionString.equals("( Col1 < 0 AND Col1 > 1 )")) + + val builtExpression = DynamicLogicExpressionBuilder.build(expressionString, encoder) + valueFromQueryValueArray(0) = encoder.encode(FloatType, 15.0f) + valueFromQueryValueArray(1) = encoder.encode(FloatType, -5.0f) + assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(FloatType, 10.0f) + valueFromQueryValueArray(1) = encoder.encode(FloatType, -1.0f) + assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(FloatType, -10.0f) + valueFromQueryValueArray(1) = encoder.encode(FloatType, -20.0f) + assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + } + + test("String Type") { + val leftLogic = new LessThanLogicExpression("Col1", 0) + val rightLogic = new GreaterThanLogicExpression("Col1", 1) + val andLogic = new AndLogicExpression(leftLogic, rightLogic) + + val columnToCurrentRowValueMap = new util.HashMap[String, ByteArrayComparable]() + + columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes("row005"))) + val valueFromQueryValueArray = new Array[Array[Byte]](2) + valueFromQueryValueArray(0) = encoder.encode(StringType, "row015") + valueFromQueryValueArray(1) = encoder.encode(StringType, "row000") + assert(andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(StringType, "row004") + valueFromQueryValueArray(1) = encoder.encode(StringType, "row000") + assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(StringType, "row020") + valueFromQueryValueArray(1) = encoder.encode(StringType, "row010") + assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + val expressionString = andLogic.toExpressionString + // Note that here 0 and 1 is index, instead of value. 
+ assert(expressionString.equals("( Col1 < 0 AND Col1 > 1 )")) + + val builtExpression = DynamicLogicExpressionBuilder.build(expressionString, encoder) + valueFromQueryValueArray(0) = encoder.encode(StringType, "row015") + valueFromQueryValueArray(1) = encoder.encode(StringType, "row000") + assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(StringType, "row004") + valueFromQueryValueArray(1) = encoder.encode(StringType, "row000") + assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + + valueFromQueryValueArray(0) = encoder.encode(StringType, "row020") + valueFromQueryValueArray(1) = encoder.encode(StringType, "row010") + assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) } + test("Boolean Type") { + val leftLogic = new LessThanLogicExpression("Col1", 0) + val rightLogic = new GreaterThanLogicExpression("Col1", 1) + + val columnToCurrentRowValueMap = new util.HashMap[String, ByteArrayComparable]() + columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes(false))) + val valueFromQueryValueArray = new Array[Array[Byte]](2) + valueFromQueryValueArray(0) = encoder.encode(BooleanType, true) + valueFromQueryValueArray(1) = encoder.encode(BooleanType, false) + assert(leftLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + assert(!rightLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray)) + } } diff --git hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/PartitionFilterSuite.scala hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/PartitionFilterSuite.scala new file mode 100644 index 0000000..bd32ff9 --- /dev/null +++ hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/PartitionFilterSuite.scala @@ -0,0 +1,498 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.hbase.spark + +import org.apache.hadoop.hbase.spark.datasources.HBaseSparkConf +import org.apache.hadoop.hbase.{TableName, HBaseTestingUtility} +import org.apache.spark.sql.datasources.hbase.HBaseTableCatalog +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.{SparkConf, SparkContext, Logging} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} + +case class FilterRangeRecord( + col0: Integer, + col1: Boolean, + col2: Double, + col3: Float, + col4: Int, + col5: Long, + col6: Short, + col7: String, + col8: Byte) + +object FilterRangeRecord { + def apply(i: Int): FilterRangeRecord = { + FilterRangeRecord(if (i % 2 == 0) i else -i, + i % 2 == 0, + if (i % 2 == 0) i.toDouble else -i.toDouble, + i.toFloat, + if (i % 2 == 0) i else -i, + i.toLong, + i.toShort, + s"String$i extra", + i.toByte) + } +} + +class PartitionFilterSuite extends FunSuite with + BeforeAndAfterEach with BeforeAndAfterAll with Logging { + @transient var sc: SparkContext = null + var TEST_UTIL: HBaseTestingUtility = new HBaseTestingUtility + + var sqlContext: SQLContext = null + var df: DataFrame = null + var srcDf: DataFrame = null + + def withCatalog(cat: String): DataFrame = { + sqlContext + .read + .options(Map(HBaseTableCatalog.tableCatalog -> cat)) + .format("org.apache.hadoop.hbase.spark") + .load() + } + + override def beforeAll() { + + TEST_UTIL.startMiniCluster + val sparkConf = new SparkConf + sparkConf.set(HBaseSparkConf.BLOCK_CACHE_ENABLE, "true") + sparkConf.set(HBaseSparkConf.BATCH_NUM, "100") + sparkConf.set(HBaseSparkConf.CACHE_SIZE, "100") + + sc = new SparkContext("local", "test", sparkConf) + new HBaseContext(sc, TEST_UTIL.getConfiguration) + sqlContext = new SQLContext(sc) + } + + override def afterAll() { + logInfo("shutting down minicluster") + TEST_UTIL.shutdownMiniCluster() + + sc.stop() + } + + override def beforeEach(): Unit = { + DefaultSourceStaticUtils.lastFiveExecutionRules.clear() + } + + val catalog = s"""{ + |"table":{"namespace":"default", "name":"rangeTable"}, + |"rowkey":"key", + |"columns":{ + |"col0":{"cf":"rowkey", "col":"key", "type":"int"}, + |"col1":{"cf":"cf1", "col":"col1", "type":"boolean"}, + |"col2":{"cf":"cf2", "col":"col2", "type":"double"}, + |"col3":{"cf":"cf3", "col":"col3", "type":"float"}, + |"col4":{"cf":"cf4", "col":"col4", "type":"int"}, + |"col5":{"cf":"cf5", "col":"col5", "type":"bigint"}, + |"col6":{"cf":"cf6", "col":"col6", "type":"smallint"}, + |"col7":{"cf":"cf7", "col":"col7", "type":"string"}, + |"col8":{"cf":"cf8", "col":"col8", "type":"tinyint"} + |} + |}""".stripMargin + + test("populate rangeTable") { + val sql = sqlContext + import sql.implicits._ + + val data = (0 until 32).map { i => + FilterRangeRecord(i) + } + + srcDf = sc.parallelize(data).toDF + srcDf.write.options( + Map(HBaseTableCatalog.tableCatalog -> catalog, HBaseTableCatalog.newTable -> "5")) + .format("org.apache.hadoop.hbase.spark") + .save() + } + test("rangeTable full query") { + val df = withCatalog(catalog) + df.show + assert(df.count() === 32) + } + + /** + *expected result: only showing top 20 rows + *+----+ + *|col0| + *+----+ + *| -31| + *| -29| + *| -27| + *| -25| + *| -23| + *| -21| + *| -19| + *| -17| + *| -15| + *| -13| + *| -11| + *| -9| + *| -7| + *| -5| + *| -3| + *| -1| + *+----+ + */ + test("rangeTable rowkey less than 0") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"col0" < 0).select($"col0") + s.show + assert(s.count() === 16) + val 
srcS = srcDf.filter($"col0" < 0).select($"col0") + assert(srcS.collect().toSet === s.collect().toSet) + } + + /** + *expected result: only showing top 20 rows + *+----+ + *|col4| + *+----+ + *| -31| + *| -29| + *| -27| + *| -25| + *| -23| + *| -21| + *| -19| + *| -17| + *| -15| + *| -13| + *| -11| + *| -9| + *| -7| + *| -5| + *| -3| + *| -1| + *+----+ + */ + test("rangeTable int col less than 0") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"col4" < 0).select($"col4") + s.show + assert(s.count() === 16) + val srcS = srcDf.filter($"col4" < 0).select($"col4") + assert(srcS.collect().toSet === s.collect().toSet) + } + + /** + *expected result: only showing top 20 rows + *+-----+ + *| col2| + *+-----+ + *| 0.0| + *| 2.0| + *|-31.0| + *|-29.0| + *|-27.0| + *|-25.0| + *|-23.0| + *|-21.0| + *|-19.0| + *|-17.0| + *|-15.0| + *|-13.0| + *|-11.0| + *| -9.0| + *| -7.0| + *| -5.0| + *| -3.0| + *| -1.0| + *+-----+ + */ + test("rangeTable double col less than 0") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"col2" < 3.0).select($"col2") + s.show + assert(s.count() === 18) + val srcS = srcDf.filter($"col2" < 3.0).select($"col2") + assert(srcS.collect().toSet === s.collect().toSet) + } + + /** + * expected result: only showing top 20 rows + *+----+ + *|col0| + *+----+ + *| -31| + *| -29| + *| -27| + *| -25| + *| -23| + *| -21| + *| -19| + *| -17| + *| -15| + *| -13| + *| -11| + *+----+ + * + */ + test("rangeTable lessequal than -10") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"col0" <= -10).select($"col0") + s.show + assert(s.count() === 11) + val srcS = srcDf.filter($"col0" <= -10).select($"col0") + assert(srcS.collect().toSet === s.collect().toSet) + } + + /** + *expected result: only showing top 20 rows + *+----+ + *|col0| + *+----+ + *| -31| + *| -29| + *| -27| + *| -25| + *| -23| + *| -21| + *| -19| + *| -17| + *| -15| + *| -13| + *| -11| + *| -9| + *+----+ + */ + test("rangeTable lessequal than -9") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"col0" <= -9).select($"col0") + s.show + assert(s.count() === 12) + val srcS = srcDf.filter($"col0" <= -9).select($"col0") + assert(srcS.collect().toSet === s.collect().toSet) + } + + /** + *expected result: only showing top 20 rows + *+----+ + *|col0| + *+----+ + *| 0| + *| 2| + *| 4| + *| 6| + *| 8| + *| 10| + *| 12| + *| 14| + *| 16| + *| 18| + *| 20| + *| 22| + *| 24| + *| 26| + *| 28| + *| 30| + *| -9| + *| -7| + *| -5| + *| -3| + *+----+ + */ + test("rangeTable greaterequal than -9") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"col0" >= -9).select($"col0") + s.show + assert(s.count() === 21) + val srcS = srcDf.filter($"col0" >= -9).select($"col0") + assert(srcS.collect().toSet === s.collect().toSet) + } + + /** + *expected result: only showing top 20 rows + *+----+ + *|col0| + *+----+ + *| 0| + *| 2| + *| 4| + *| 6| + *| 8| + *| 10| + *| 12| + *| 14| + *| 16| + *| 18| + *| 20| + *| 22| + *| 24| + *| 26| + *| 28| + *| 30| + *+----+ + */ + test("rangeTable greaterequal than 0") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"col0" >= 0).select($"col0") + s.show + assert(s.count() === 16) + val srcS = srcDf.filter($"col0" >= 0).select($"col0") + assert(srcS.collect().toSet === s.collect().toSet) + } + + /** + 
*expected result: only showing top 20 rows + *+----+ + *|col0| + *+----+ + *| 12| + *| 14| + *| 16| + *| 18| + *| 20| + *| 22| + *| 24| + *| 26| + *| 28| + *| 30| + *+----+ + */ + test("rangeTable greater than 10") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"col0" > 10).select($"col0") + s.show + assert(s.count() === 10) + val srcS = srcDf.filter($"col0" > 10).select($"col0") + assert(srcS.collect().toSet === s.collect().toSet) + } + + /** + *expected result: only showing top 20 rows + *+----+ + *|col0| + *+----+ + *| 0| + *| 2| + *| 4| + *| 6| + *| 8| + *| 10| + *| -9| + *| -7| + *| -5| + *| -3| + *| -1| + *+----+ + */ + test("rangeTable and") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"col0" > -10 && $"col0" <= 10).select($"col0") + s.show + assert(s.count() === 11) + val srcS = srcDf.filter($"col0" > -10 && $"col0" <= 10).select($"col0") + assert(srcS.collect().toSet === s.collect().toSet) + } + + /** + *expected result: only showing top 20 rows + *+----+ + *|col0| + *+----+ + *| 12| + *| 14| + *| 16| + *| 18| + *| 20| + *| 22| + *| 24| + *| 26| + *| 28| + *| 30| + *| -31| + *| -29| + *| -27| + *| -25| + *| -23| + *| -21| + *| -19| + *| -17| + *| -15| + *| -13| + *+----+ + */ + + test("or") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"col0" <= -10 || $"col0" > 10).select($"col0") + s.show + assert(s.count() === 21) + val srcS = srcDf.filter($"col0" <= -10 || $"col0" > 10).select($"col0") + assert(srcS.collect().toSet === s.collect().toSet) + } + + /** + *expected result: only showing top 20 rows + *+----+ + *|col0| + *+----+ + *| 0| + *| 2| + *| 4| + *| 6| + *| 8| + *| 10| + *| 12| + *| 14| + *| 16| + *| 18| + *| 20| + *| 22| + *| 24| + *| 26| + *| 28| + *| 30| + *| -31| + *| -29| + *| -27| + *| -25| + *+----+ + */ + test("rangeTable all") { + val sql = sqlContext + import sql.implicits._ + val df = withCatalog(catalog) + val s = df.filter($"col0" >= -100).select($"col0") + s.show + assert(s.count() === 32) + val srcS = srcDf.filter($"col0" >= -100).select($"col0") + assert(srcS.collect().toSet === s.collect().toSet) + } +}
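For context on why an encoder is now threaded through DynamicLogicExpressionBuilder.build and the typed tests above: the push-down filter ultimately compares byte arrays, and plain Bytes.toBytes output does not sort numerically for signed integers or floating-point values. The sketch below is illustration only, not part of the patch; it uses nothing beyond org.apache.hadoop.hbase.util.Bytes from hbase-common and shows the ordering inversion that a type-aware BytesEncoder has to account for.

import org.apache.hadoop.hbase.util.Bytes

object EncoderOrderingSketch {
  def main(args: Array[String]): Unit = {
    // -1 encodes to 0xFFFFFFFF, which is lexicographically (unsigned) greater than
    // 0x00000001, even though -1 < 1 numerically.
    assert(Bytes.compareTo(Bytes.toBytes(-1), Bytes.toBytes(1)) > 0)

    // The same inversion happens for negative doubles: the sign bit makes the raw
    // IEEE-754 bytes of -4.0 compare greater than those of 15.0.
    assert(Bytes.compareTo(Bytes.toBytes(-4.0d), Bytes.toBytes(15.0d)) > 0)
  }
}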