diff --git hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
index 7970816..32eb312 100644
--- hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
+++ hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.ConcurrentLinkedQueue
 import org.apache.hadoop.hbase.client._
 import org.apache.hadoop.hbase.io.ImmutableBytesWritable
 import org.apache.hadoop.hbase.mapred.TableOutputFormat
+import org.apache.hadoop.hbase.spark.datasources.BoundRange
 import org.apache.hadoop.hbase.spark.datasources.HBaseSparkConf
 import org.apache.hadoop.hbase.spark.datasources.HBaseTableScanRDD
 import org.apache.hadoop.hbase.spark.datasources.SerializableConfiguration
@@ -419,14 +420,21 @@ case class HBaseRelation (
         val field = catalog.getField(attr)
         if (field != null) {
           if (field.isRowKey) {
-            parentRowKeyFilter.mergeIntersect(new RowKeyFilter(null,
-              new ScanRange(DefaultSourceStaticUtils.getByteValue(field,
-                value.toString), false,
-                new Array[Byte](0), true)))
+            val b = BoundRange(value)
+            var inc = false
+            b.map(_.less.map { x =>
+              val r = new RowKeyFilter(null,
+                new ScanRange(x.upper, inc, x.low, true)
+              )
+              inc = true
+              r
+            }).map { x =>
+              x.reduce { (i, j) =>
+                i.mergeUnion(j)
+              }
+            }.map(parentRowKeyFilter.mergeIntersect(_))
           }
-          val byteValue =
-            DefaultSourceStaticUtils.getByteValue(catalog.getField(attr),
-              value.toString)
+          val byteValue = FilterOps.encode(field.dt, value)
           valueArray += byteValue
         }
         new LessThanLogicExpression(attr, valueArray.length - 1)
@@ -434,13 +442,20 @@ case class HBaseRelation (
         val field = catalog.getField(attr)
         if (field != null) {
           if (field.isRowKey) {
-            parentRowKeyFilter.mergeIntersect(new RowKeyFilter(null,
-              new ScanRange(null, true, DefaultSourceStaticUtils.getByteValue(field,
-                value.toString), false)))
+            val b = BoundRange(value)
+            var inc = false
+            b.map(_.greater.map{ x =>
+              val r = new RowKeyFilter(null,
+                new ScanRange(x.upper, true, x.low, inc))
+              inc = true
+              r
+            }).map { x =>
+              x.reduce { (i, j) =>
+                i.mergeUnion(j)
+              }
+            }.map(parentRowKeyFilter.mergeIntersect(_))
           }
-          val byteValue =
-            DefaultSourceStaticUtils.getByteValue(field,
-              value.toString)
+          val byteValue = FilterOps.encode(field.dt, value)
           valueArray += byteValue
         }
         new GreaterThanLogicExpression(attr, valueArray.length - 1)
@@ -448,14 +463,17 @@ case class HBaseRelation (
         val field = catalog.getField(attr)
         if (field != null) {
           if (field.isRowKey) {
-            parentRowKeyFilter.mergeIntersect(new RowKeyFilter(null,
-              new ScanRange(DefaultSourceStaticUtils.getByteValue(field,
-                value.toString), true,
-                new Array[Byte](0), true)))
+            val b = BoundRange(value)
+            b.map(_.less.map(x =>
+              new RowKeyFilter(null,
+                new ScanRange(x.upper, true, x.low, true))))
+              .map { x =>
+                x.reduce{ (i, j) =>
+                  i.mergeUnion(j)
+                }
+              }.map(parentRowKeyFilter.mergeIntersect(_))
           }
-          val byteValue =
-            DefaultSourceStaticUtils.getByteValue(catalog.getField(attr),
-              value.toString)
+          val byteValue = FilterOps.encode(field.dt, value)
           valueArray += byteValue
         }
         new LessThanOrEqualLogicExpression(attr, valueArray.length - 1)
@@ -463,15 +481,18 @@ case class HBaseRelation (
         val field = catalog.getField(attr)
         if (field != null) {
           if (field.isRowKey) {
-            parentRowKeyFilter.mergeIntersect(new RowKeyFilter(null,
-              new ScanRange(null, true, DefaultSourceStaticUtils.getByteValue(field,
-                value.toString), true)))
+            val b = BoundRange(value)
+            b.map(_.greater.map(x =>
+              new RowKeyFilter(null,
+                new ScanRange(x.upper, true, x.low, true))))
+              .map { x =>
+                x.reduce { (i, j) =>
+                  i.mergeUnion(j)
+                }
+              }.map(parentRowKeyFilter.mergeIntersect(_))
           }
-          val byteValue =
-            DefaultSourceStaticUtils.getByteValue(catalog.getField(attr),
-              value.toString)
+          val byteValue = FilterOps.encode(field.dt, value)
           valueArray += byteValue
-        }
         new GreaterThanOrEqualLogicExpression(attr, valueArray.length - 1)
       case Or(left, right) =>
@@ -595,9 +616,14 @@ class ScanRange(var upperBound:Array[Byte], var isUpperBoundEqualTo:Boolean,
 
     //Then see if leftRange goes to null or if leftRange.upperBound
     // upper is greater or equals to rightRange.lowerBound
-    if (leftRange.upperBound == null ||
-      Bytes.compareTo(leftRange.upperBound, rightRange.lowerBound) >= 0) {
-      new ScanRange(leftRange.upperBound, leftRange.isUpperBoundEqualTo, rightRange.lowerBound, rightRange.isLowerBoundEqualTo)
+    if (compareRange(leftRange.upperBound, rightRange.lowerBound) >= 0) {
+      if (compareRange(leftRange.upperBound, rightRange.upperBound) >=0) {
+        new ScanRange(rightRange.upperBound, rightRange.isUpperBoundEqualTo,
+          rightRange.lowerBound, rightRange.isLowerBoundEqualTo)
+      } else {
+        new ScanRange(leftRange.upperBound, leftRange.isUpperBoundEqualTo,
+          rightRange.lowerBound, rightRange.isLowerBoundEqualTo)
+      }
     } else {
       null
     }
@@ -1041,7 +1067,7 @@ class RowKeyFilter (currentPoint:Array[Byte] = null,
    *
    * @param other Filter to merge
    */
-  def mergeUnion(other:RowKeyFilter): Unit = {
+  def mergeUnion(other:RowKeyFilter): RowKeyFilter = {
     other.points.foreach( p => points += p)
 
     other.ranges.foreach( otherR => {
@@ -1053,6 +1079,7 @@ class RowKeyFilter (currentPoint:Array[Byte] = null,
       }}
       if (!doesOverLap) ranges.+=(otherR)
     })
+    this
   }
 
   /**
@@ -1061,7 +1088,7 @@ class RowKeyFilter (currentPoint:Array[Byte] = null,
    *
    * @param other Filter to merge
    */
-  def mergeIntersect(other:RowKeyFilter): Unit = {
+  def mergeIntersect(other:RowKeyFilter): RowKeyFilter = {
     val survivingPoints = new mutable.MutableList[Array[Byte]]()
     val didntSurviveFirstPassPoints = new mutable.MutableList[Array[Byte]]()
     if (points == null || points.length == 0) {
@@ -1107,6 +1134,7 @@ class RowKeyFilter (currentPoint:Array[Byte] = null,
     }
     points = survivingPoints
     ranges = survivingRanges
+    this
   }
 
   override def toString:String = {
diff --git hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpression.scala hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpression.scala
index fa61860..a211dff 100644
--- hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpression.scala
+++ hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpression.scala
@@ -19,7 +19,129 @@ package org.apache.hadoop.hbase.spark
 
 import java.util
 
+import org.apache.hadoop.hbase.spark.FilterOps.FilterOps
 import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.sql.datasources.hbase.{Field, Utils}
+import org.apache.spark.sql.types._
+
+
+object FilterOps extends Enumeration {
+  type FilterOps = Value
+  val Greater, GreaterEqual, Less, LessEqual, Equal, Unknown = Value
+  var code = 0
+  def nextCode: Byte = {
+    code += 1
+    (code - 1).asInstanceOf[Byte]
+  }
+  val BooleanEnc = nextCode
+  val ShortEnc = nextCode
+  val IntEnc = nextCode
+  val LongEnc = nextCode
+  val FloatEnc = nextCode
+  val DoubleEnc = nextCode
+  val StringEnc = nextCode
+  val BinaryEnc = nextCode
+  val TimestampEnc = nextCode
+  val UnknownEnc = nextCode
+
+  def compare(c: Int, ops: FilterOps): Boolean = {
+    ops match {
+      case Greater => c > 0
+      case GreaterEqual => c >= 0
+      case Less => c < 0
+      case LessEqual => c <= 0
+    }
+  }
+
+  def encode(dt: DataType,
+             value: Any): Array[Byte] = {
+    dt match {
+      case BooleanType =>
+        val result = new Array[Byte](Bytes.SIZEOF_BOOLEAN + 1)
+        result(0) = BooleanEnc
+        value.asInstanceOf[Boolean] match {
+          case true => result(1) = -1: Byte
+          case false => result(1) = 0: Byte
+        }
+        result
+      case ShortType =>
+        val result = new Array[Byte](Bytes.SIZEOF_SHORT + 1)
+        result(0) = ShortEnc
+        Bytes.putShort(result, 1, value.asInstanceOf[Short])
+        result
+      case IntegerType =>
+        val result = new Array[Byte](Bytes.SIZEOF_INT + 1)
+        result(0) = IntEnc
+        Bytes.putInt(result, 1, value.asInstanceOf[Int])
+        result
+      case LongType|TimestampType =>
+        val result = new Array[Byte](Bytes.SIZEOF_LONG + 1)
+        result(0) = LongEnc
+        Bytes.putLong(result, 1, value.asInstanceOf[Long])
+        result
+      case FloatType =>
+        val result = new Array[Byte](Bytes.SIZEOF_FLOAT + 1)
+        result(0) = FloatEnc
+        Bytes.putFloat(result, 1, value.asInstanceOf[Float])
+        result
+      case DoubleType =>
+        val result = new Array[Byte](Bytes.SIZEOF_DOUBLE + 1)
+        result(0) = DoubleEnc
+        Bytes.putDouble(result, 1, value.asInstanceOf[Double])
+        result
+      case BinaryType =>
+        val v = value.asInstanceOf[Array[Byte]]
+        val result = new Array[Byte](v.length + 1)
+        result(0) = BinaryEnc
+        System.arraycopy(v, 0, result, 1, v.length)
+        result
+      case StringType =>
+        val bytes = Bytes.toBytes(value.asInstanceOf[String])
+        val result = new Array[Byte](bytes.length + 1)
+        result(0) = StringEnc
+        System.arraycopy(bytes, 0, result, 1, bytes.length)
+        result
+      case _ =>
+        val bytes = Bytes.toBytes(value.toString)
+        val result = new Array[Byte](bytes.length + 1)
+        result(0) = UnknownEnc
+        System.arraycopy(bytes, 0, result, 1, bytes.length)
+        result
+    }
+  }
+
+  def filter(input: Array[Byte], offset1: Int, length1: Int,
+             filterBytes: Array[Byte], offset2: Int, length2: Int,
+             ops: FilterOps): Boolean = {
+    filterBytes(offset2) match {
+      case ShortEnc =>
+        val in = Bytes.toShort(input, offset1)
+        val value = Bytes.toShort(filterBytes, offset2 + 1)
+        compare(in.compareTo(value), ops)
+      case FilterOps.IntEnc =>
+        val in = Bytes.toInt(input, offset1)
+        val value = Bytes.toInt(filterBytes, offset2 + 1)
+        compare(in.compareTo(value), ops)
+      case FilterOps.LongEnc|FilterOps.TimestampEnc =>
+        val in = Bytes.toLong(input, offset1)
+        val value = Bytes.toLong(filterBytes, offset2 + 1)
+        compare(in.compareTo(value), ops)
+      case FilterOps.FloatEnc =>
+        val in = Bytes.toFloat(input, offset1)
+        val value = Bytes.toFloat(filterBytes, offset2 + 1)
+        compare(in.compareTo(value), ops)
+      case FilterOps.DoubleEnc =>
+        val in = Bytes.toDouble(input, offset1)
+        val value = Bytes.toDouble(filterBytes, offset2 + 1)
+        compare(in.compareTo(value), ops)
+      case _ =>
+        // for String, Byte, Binary, Boolean and other types
+        // we can use the order of the byte array directly.
+        compare(
+          Bytes.compareTo(input, offset1, length1, filterBytes, offset2 + 1, length2 - 1), ops)
+    }
+  }
+}
 
 /**
  * Dynamic logic for SQL push down logic there is an instance for most
@@ -38,9 +160,26 @@ trait DynamicLogicExpression {
     appendToExpression(strBuilder)
     strBuilder.toString()
   }
+  def filterOps: FilterOps = FilterOps.Unknown
+
   def appendToExpression(strBuilder:StringBuilder)
 }
 
+trait CompareTrait {
+  self: DynamicLogicExpression =>
+  def columnName: String
+  def valueFromQueryIndex: Int
+  def execute(columnToCurrentRowValueMap:
+              util.HashMap[String, ByteArrayComparable],
+              valueFromQueryValueArray:Array[Array[Byte]]): Boolean = {
+    val currentRowValue = columnToCurrentRowValueMap.get(columnName)
+    val valueFromQuery = valueFromQueryValueArray(valueFromQueryIndex)
+    currentRowValue != null &&
+      FilterOps.filter(currentRowValue.bytes, currentRowValue.offset, currentRowValue.length,
+        valueFromQuery, 0, valueFromQuery.length, filterOps)
+  }
+}
+
 class AndLogicExpression (val leftExpression:DynamicLogicExpression,
                           val rightExpression:DynamicLogicExpression)
   extends DynamicLogicExpression{
@@ -113,59 +252,28 @@ class IsNullLogicExpression (val columnName:String,
   }
 }
 
-class GreaterThanLogicExpression (val columnName:String,
-                                  val valueFromQueryIndex:Int)
-  extends DynamicLogicExpression{
-  override def execute(columnToCurrentRowValueMap:
-                       util.HashMap[String, ByteArrayComparable],
-                       valueFromQueryValueArray:Array[Array[Byte]]): Boolean = {
-    val currentRowValue = columnToCurrentRowValueMap.get(columnName)
-    val valueFromQuery = valueFromQueryValueArray(valueFromQueryIndex)
-
-    currentRowValue != null &&
-      Bytes.compareTo(currentRowValue.bytes,
-        currentRowValue.offset, currentRowValue.length, valueFromQuery,
-        0, valueFromQuery.length) > 0
-  }
+class GreaterThanLogicExpression (override val columnName:String,
+                                  override val valueFromQueryIndex:Int)
+  extends DynamicLogicExpression with CompareTrait{
+  override val filterOps = FilterOps.Greater
   override def appendToExpression(strBuilder: StringBuilder): Unit = {
     strBuilder.append(columnName + " > " + valueFromQueryIndex)
   }
 }
 
-class GreaterThanOrEqualLogicExpression (val columnName:String,
-                                         val valueFromQueryIndex:Int)
-  extends DynamicLogicExpression{
-  override def execute(columnToCurrentRowValueMap:
-                       util.HashMap[String, ByteArrayComparable],
-                       valueFromQueryValueArray:Array[Array[Byte]]): Boolean = {
-    val currentRowValue = columnToCurrentRowValueMap.get(columnName)
-    val valueFromQuery = valueFromQueryValueArray(valueFromQueryIndex)
-
-    currentRowValue != null &&
-      Bytes.compareTo(currentRowValue.bytes,
-        currentRowValue.offset, currentRowValue.length, valueFromQuery,
-        0, valueFromQuery.length) >= 0
-  }
+class GreaterThanOrEqualLogicExpression (override val columnName:String,
+                                         override val valueFromQueryIndex:Int)
  extends DynamicLogicExpression with CompareTrait{
+  override val filterOps = FilterOps.GreaterEqual
   override def appendToExpression(strBuilder: StringBuilder): Unit = {
     strBuilder.append(columnName + " >= " + valueFromQueryIndex)
   }
 }
 
-class LessThanLogicExpression (val columnName:String,
-                               val valueFromQueryIndex:Int)
-  extends DynamicLogicExpression{
-  override def execute(columnToCurrentRowValueMap:
-                       util.HashMap[String, ByteArrayComparable],
-                       valueFromQueryValueArray:Array[Array[Byte]]): Boolean = {
-    val currentRowValue = columnToCurrentRowValueMap.get(columnName)
-    val valueFromQuery = valueFromQueryValueArray(valueFromQueryIndex)
-
-    currentRowValue != null &&
-      Bytes.compareTo(currentRowValue.bytes,
-        currentRowValue.offset, currentRowValue.length, valueFromQuery,
-        0, valueFromQuery.length) < 0
-  }
-
+class LessThanLogicExpression (override val columnName:String,
+                               override val valueFromQueryIndex:Int)
+  extends DynamicLogicExpression with CompareTrait {
+  override val filterOps = FilterOps.Less
   override def appendToExpression(strBuilder: StringBuilder): Unit = {
     strBuilder.append(columnName + " < " + valueFromQueryIndex)
   }
@@ -173,19 +281,8 @@ class LessThanLogicExpression (val columnName:String,
 
 class LessThanOrEqualLogicExpression (val columnName:String,
                                       val valueFromQueryIndex:Int)
-  extends DynamicLogicExpression{
-  override def execute(columnToCurrentRowValueMap:
-                       util.HashMap[String, ByteArrayComparable],
-                       valueFromQueryValueArray:Array[Array[Byte]]): Boolean = {
-    val currentRowValue = columnToCurrentRowValueMap.get(columnName)
-    val valueFromQuery = valueFromQueryValueArray(valueFromQueryIndex)
-
-    currentRowValue != null &&
-      Bytes.compareTo(currentRowValue.bytes,
-        currentRowValue.offset, currentRowValue.length, valueFromQuery,
-        0, valueFromQuery.length) <= 0
-  }
-
+  extends DynamicLogicExpression with CompareTrait{
+  override val filterOps = FilterOps.LessEqual
   override def appendToExpression(strBuilder: StringBuilder): Unit = {
     strBuilder.append(columnName + " <= " + valueFromQueryIndex)
   }
@@ -197,7 +294,7 @@ class PassThroughLogicExpression() extends DynamicLogicExpression {
                        valueFromQueryValueArray: Array[Array[Byte]]): Boolean = true
 
   override def appendToExpression(strBuilder: StringBuilder): Unit = {
-    strBuilder.append("Pass")
+    strBuilder.append("dummy Pass -1")
   }
 }
 
@@ -245,7 +342,7 @@ object DynamicLogicExpressionBuilder {
     } else if (command.equals("isNotNull")) {
       (new IsNullLogicExpression(expressionArray(offSet), true), offSet + 2)
     } else if (command.equals("Pass")) {
-      (new PassThroughLogicExpression, offSet + 2)
+      (new PassThroughLogicExpression, offSet + 3)
     } else {
       throw new Throwable("Unknown logic command:" + command)
     }
diff --git hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/BoundRange.scala hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/BoundRange.scala
new file mode 100644
index 0000000..3152a64
--- /dev/null
+++ hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/BoundRange.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hbase.spark.datasources
+
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.hadoop.hbase.spark.hbase._
+import org.apache.spark.Logging
+import org.apache.spark.unsafe.types.UTF8String
+
+// Data type for a range whose size is known:
+// a lower bound and an upper bound for each range.
+// If the data order is the same as the byte order, then left = mid = right.
+// For data types whose order is not the same as the byte order, left != mid != right.
+// In that case, left is the max, right is the min, and mid is the byte form of the value.
+// This way, the scan will cover the whole range and will not miss any data.
+// Typically, mid is used only in Equal, in which case the order does not matter.
+case class BoundRange(
+    low: Array[Byte],
+    upper: Array[Byte])
+
+// The ranges in less and greater have to be lexicographically ordered.
+case class BoundRanges(less: Array[BoundRange], greater: Array[BoundRange], value: Array[Byte])
+
+object BoundRange extends Logging{
+  def apply(in: Any): Option[BoundRanges] = in match {
+    // For short, integer, and long, the numeric order is consistent with the byte-array order
+    // regardless of sign, except that negative numbers sort after positive numbers as raw bytes.
+    case a: Integer =>
+      val b = Bytes.toBytes(a)
+      if (a >= 0) {
+        logDebug(s"range is 0 to $a and ${Integer.MIN_VALUE} to -1")
+        Some(BoundRanges(
+          Array(BoundRange(Bytes.toBytes(0: Int), b),
+            BoundRange(Bytes.toBytes(Integer.MIN_VALUE), Bytes.toBytes(-1: Int))),
+          Array(BoundRange(b, Bytes.toBytes(Integer.MAX_VALUE))), b))
+      } else {
+        Some(BoundRanges(
+          Array(BoundRange(Bytes.toBytes(Integer.MIN_VALUE), b)),
+          Array(BoundRange(b, Bytes.toBytes(-1: Integer)),
+            BoundRange(Bytes.toBytes(0: Int), Bytes.toBytes(Integer.MAX_VALUE))), b))
+      }
+    case a: Long =>
+      val b = Bytes.toBytes(a)
+      if (a >= 0) {
+        Some(BoundRanges(
+          Array(BoundRange(Bytes.toBytes(0: Long), b),
+            BoundRange(Bytes.toBytes(Long.MinValue), Bytes.toBytes(-1: Long))),
+          Array(BoundRange(b, Bytes.toBytes(Long.MaxValue))), b))
+      } else {
+        Some(BoundRanges(
+          Array(BoundRange(Bytes.toBytes(Long.MinValue), b)),
+          Array(BoundRange(b, Bytes.toBytes(-1: Long)),
+            BoundRange(Bytes.toBytes(0: Long), Bytes.toBytes(Long.MaxValue))), b))
+      }
+    case a: Short =>
+      val b = Bytes.toBytes(a)
+      if (a >= 0) {
+        Some(BoundRanges(
+          Array(BoundRange(Bytes.toBytes(0: Short), b),
+            BoundRange(Bytes.toBytes(Short.MinValue), Bytes.toBytes(-1: Short))),
+          Array(BoundRange(b, Bytes.toBytes(Short.MaxValue))), b))
+      } else {
+        Some(BoundRanges(
+          Array(BoundRange(Bytes.toBytes(Short.MinValue), b)),
+          Array(BoundRange(b, Bytes.toBytes(-1: Short)),
+            BoundRange(Bytes.toBytes(0: Short), Bytes.toBytes(Short.MaxValue))), b))
+      }
+    // For both double and float, the order of positive numbers is consistent
+    // with the byte-array order, but the order of negative numbers is the reverse
+    // of the byte-array order. Please refer to IEEE-754 and
+    // https://en.wikipedia.org/wiki/Single-precision_floating-point_format
+    case a: Double =>
+      val b = Bytes.toBytes(a)
+      if (a >= 0.0d) {
+        Some(BoundRanges(
+          Array(BoundRange(Bytes.toBytes(0.0d), b),
+            BoundRange(Bytes.toBytes(-0.0d), Bytes.toBytes(Double.MinValue))),
+          Array(BoundRange(b, Bytes.toBytes(Double.MaxValue))), b))
+      } else {
+        Some(BoundRanges(
+          Array(BoundRange(b, Bytes.toBytes(Double.MinValue))),
+          Array(BoundRange(Bytes.toBytes(-0.0d), b),
+            BoundRange(Bytes.toBytes(0.0d), Bytes.toBytes(Double.MaxValue))), b))
+      }
+    case a: Float =>
+      val b = Bytes.toBytes(a)
+      if (a >= 0.0f) {
+        Some(BoundRanges(
+          Array(BoundRange(Bytes.toBytes(0.0f), b),
+            BoundRange(Bytes.toBytes(-0.0f), Bytes.toBytes(Float.MinValue))),
+          Array(BoundRange(b, Bytes.toBytes(Float.MaxValue))), b))
+      } else {
+        Some(BoundRanges(
+          Array(BoundRange(b, Bytes.toBytes(Float.MinValue))),
+          Array(BoundRange(Bytes.toBytes(-0.0f), b),
+            BoundRange(Bytes.toBytes(0.0f), Bytes.toBytes(Float.MaxValue))), b))
+      }
+    case a: Array[Byte] =>
+      Some(BoundRanges(
+        Array(BoundRange(bytesMin, a)),
+        Array(BoundRange(a, bytesMax)), a))
+    case a: Byte =>
+      val b = Array(a)
+      Some(BoundRanges(
+        Array(BoundRange(bytesMin, b)),
+        Array(BoundRange(b, bytesMax)), b))
+    case a: String =>
+      val b = Bytes.toBytes(a)
+      Some(BoundRanges(
+        Array(BoundRange(bytesMin, b)),
+        Array(BoundRange(b, bytesMax)), b))
+    case a: UTF8String =>
+      val b = a.getBytes
+      Some(BoundRanges(
+        Array(BoundRange(bytesMin, b)),
+        Array(BoundRange(b, bytesMax)), b))
+    case _ => None
+  }
+}
diff --git hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/package.scala hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/package.scala
index 4ff0413..ce7b55a 100644
--- hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/package.scala
+++ hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/package.scala
@@ -23,6 +23,8 @@ import scala.math.Ordering
 
 package object hbase {
   type HBaseType = Array[Byte]
+  def bytesMin = new Array[Byte](0)
+  def bytesMax = null
   val ByteMax = -1.asInstanceOf[Byte]
   val ByteMin = 0.asInstanceOf[Byte]
   val ord: Ordering[HBaseType] = new Ordering[HBaseType] {
diff --git hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/HBaseTableCatalog.scala hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/HBaseTableCatalog.scala
index 831c7de..c2d611f 100644
--- hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/HBaseTableCatalog.scala
+++ hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/HBaseTableCatalog.scala
@@ -156,7 +156,7 @@ case class HBaseTableCatalog(
   def get(key: String) = params.get(key)
 
   // Setup the start and length for each dimension of row key at runtime.
-  def dynSetupRowKey(rowKey: HBaseType) {
+  def dynSetupRowKey(rowKey: Array[Byte]) {
     logDebug(s"length: ${rowKey.length}")
     if(row.varLength) {
       var start = 0
diff --git hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/Utils.scala hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/Utils.scala
index 73d054d..9f49f27 100644
--- hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/Utils.scala
+++ hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/Utils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.datasources.hbase
 
 import org.apache.hadoop.hbase.spark.AvroSerdes
+import org.apache.hadoop.hbase.spark.FilterOps.FilterOps
 import org.apache.hadoop.hbase.util.Bytes
 import org.apache.spark.sql.execution.SparkSqlSerializer
 import org.apache.spark.sql.types._
diff --git hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala
index 500967d..73adfb9 100644
--- hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala
+++ hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala
@@ -322,14 +322,7 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
       equals("( KEY_FIELD < 0 OR KEY_FIELD > 1 )"))
 
     assert(executionRules.rowKeyFilter.points.size == 0)
-    assert(executionRules.rowKeyFilter.ranges.size == 1)
-
-    val scanRange1 = executionRules.rowKeyFilter.ranges.get(0).get
-    assert(Bytes.equals(scanRange1.lowerBound,Bytes.toBytes("")))
-    assert(scanRange1.upperBound == null)
-    assert(scanRange1.isLowerBoundEqualTo)
-    assert(scanRange1.isUpperBoundEqualTo)
-
+    assert(executionRules.rowKeyFilter.ranges.size == 2)
     assert(results.length == 5)
   }
 
@@ -351,18 +344,14 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
 
     assert(executionRules.rowKeyFilter.points.size == 0)
-    assert(executionRules.rowKeyFilter.ranges.size == 2)
+    assert(executionRules.rowKeyFilter.ranges.size == 3)
 
     val scanRange1 = executionRules.rowKeyFilter.ranges.get(0).get
-    assert(Bytes.equals(scanRange1.lowerBound,Bytes.toBytes("")))
     assert(Bytes.equals(scanRange1.upperBound, Bytes.toBytes(2)))
     assert(scanRange1.isLowerBoundEqualTo)
    assert(!scanRange1.isUpperBoundEqualTo)
 
     val scanRange2 = executionRules.rowKeyFilter.ranges.get(1).get
-    assert(Bytes.equals(scanRange2.lowerBound, Bytes.toBytes(4)))
-    assert(scanRange2.upperBound == null)
-    assert(!scanRange2.isLowerBoundEqualTo)
     assert(scanRange2.isUpperBoundEqualTo)
 
     assert(results.length == 2)
diff --git hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpressionSuite.scala hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpressionSuite.scala
index 3140ebd..745228a 100644
--- hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpressionSuite.scala
+++ hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DynamicLogicExpressionSuite.scala
@@ -21,6 +21,7 @@ import java.util
 
 import org.apache.hadoop.hbase.util.Bytes
 import org.apache.spark.Logging
+import org.apache.spark.sql.types.IntegerType
 import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
 
 class DynamicLogicExpressionSuite extends FunSuite with
@@ -35,16 +36,16 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
     columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes(10)))
 
     val valueFromQueryValueArray = new Array[Array[Byte]](2)
-    valueFromQueryValueArray(0) = Bytes.toBytes(15)
-    valueFromQueryValueArray(1) = Bytes.toBytes(5)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 15)
+    valueFromQueryValueArray(1) = FilterOps.encode(IntegerType, 5)
     assert(andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
 
-    valueFromQueryValueArray(0) = Bytes.toBytes(10)
-    valueFromQueryValueArray(1) = Bytes.toBytes(5)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 10)
+    valueFromQueryValueArray(1) = FilterOps.encode(IntegerType, 5)
     assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
 
-    valueFromQueryValueArray(0) = Bytes.toBytes(15)
-    valueFromQueryValueArray(1) = Bytes.toBytes(10)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 15)
+    valueFromQueryValueArray(1) = FilterOps.encode(IntegerType, 10)
     assert(!andLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
 
     val expressionString = andLogic.toExpressionString
@@ -52,16 +53,16 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
     assert(expressionString.equals("( Col1 < 0 AND Col1 > 1 )"))
 
     val builtExpression = DynamicLogicExpressionBuilder.build(expressionString)
-    valueFromQueryValueArray(0) = Bytes.toBytes(15)
-    valueFromQueryValueArray(1) = Bytes.toBytes(5)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 15)
+    valueFromQueryValueArray(1) = FilterOps.encode(IntegerType, 5)
     assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
 
-    valueFromQueryValueArray(0) = Bytes.toBytes(10)
-    valueFromQueryValueArray(1) = Bytes.toBytes(5)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 10)
+    valueFromQueryValueArray(1) = FilterOps.encode(IntegerType, 5)
     assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
 
-    valueFromQueryValueArray(0) = Bytes.toBytes(15)
-    valueFromQueryValueArray(1) = Bytes.toBytes(10)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 15)
+    valueFromQueryValueArray(1) = FilterOps.encode(IntegerType, 10)
     assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
   }
 
@@ -75,20 +76,20 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
     columnToCurrentRowValueMap.put("Col1", new ByteArrayComparable(Bytes.toBytes(10)))
 
     val valueFromQueryValueArray = new Array[Array[Byte]](2)
-    valueFromQueryValueArray(0) = Bytes.toBytes(15)
-    valueFromQueryValueArray(1) = Bytes.toBytes(5)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 15)
+    valueFromQueryValueArray(1) = FilterOps.encode(IntegerType, 5)
     assert(OrLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
 
-    valueFromQueryValueArray(0) = Bytes.toBytes(10)
-    valueFromQueryValueArray(1) = Bytes.toBytes(5)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 10)
+    valueFromQueryValueArray(1) = FilterOps.encode(IntegerType, 5)
     assert(OrLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
 
-    valueFromQueryValueArray(0) = Bytes.toBytes(15)
-    valueFromQueryValueArray(1) = Bytes.toBytes(10)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 15)
+    valueFromQueryValueArray(1) = FilterOps.encode(IntegerType, 10)
     assert(OrLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
 
-    valueFromQueryValueArray(0) = Bytes.toBytes(10)
-    valueFromQueryValueArray(1) = Bytes.toBytes(10)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 10)
+    valueFromQueryValueArray(1) = FilterOps.encode(IntegerType, 10)
     assert(!OrLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
 
     val expressionString = OrLogic.toExpressionString
@@ -96,20 +97,20 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
     assert(expressionString.equals("( Col1 < 0 OR Col1 > 1 )"))
 
     val builtExpression = DynamicLogicExpressionBuilder.build(expressionString)
-    valueFromQueryValueArray(0) = Bytes.toBytes(15)
-    valueFromQueryValueArray(1) = Bytes.toBytes(5)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 15)
+    valueFromQueryValueArray(1) = FilterOps.encode(IntegerType, 5)
     assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
 
-    valueFromQueryValueArray(0) = Bytes.toBytes(10)
-    valueFromQueryValueArray(1) = Bytes.toBytes(5)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 10)
+    valueFromQueryValueArray(1) = FilterOps.encode(IntegerType, 5)
     assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
 
-    valueFromQueryValueArray(0) = Bytes.toBytes(15)
-    valueFromQueryValueArray(1) = Bytes.toBytes(10)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 15)
+    valueFromQueryValueArray(1) = FilterOps.encode(IntegerType, 10)
     assert(builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
 
-    valueFromQueryValueArray(0) = Bytes.toBytes(10)
-    valueFromQueryValueArray(1) = Bytes.toBytes(10)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 10)
+    valueFromQueryValueArray(1) = FilterOps.encode(IntegerType, 10)
     assert(!builtExpression.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
   }
 
@@ -127,40 +128,40 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
     val valueFromQueryValueArray = new Array[Array[Byte]](1)
 
     //great than
-    valueFromQueryValueArray(0) = Bytes.toBytes(10)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 10)
     assert(!greaterLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
 
-    valueFromQueryValueArray(0) = Bytes.toBytes(20)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 20)
     assert(!greaterLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
 
     //great than and equal
-    valueFromQueryValueArray(0) = Bytes.toBytes(5)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 5)
     assert(greaterAndEqualLogic.execute(columnToCurrentRowValueMap,
       valueFromQueryValueArray))
 
-    valueFromQueryValueArray(0) = Bytes.toBytes(10)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 10)
     assert(greaterAndEqualLogic.execute(columnToCurrentRowValueMap,
       valueFromQueryValueArray))
 
-    valueFromQueryValueArray(0) = Bytes.toBytes(20)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 20)
     assert(!greaterAndEqualLogic.execute(columnToCurrentRowValueMap,
       valueFromQueryValueArray))
 
     //less than
-    valueFromQueryValueArray(0) = Bytes.toBytes(10)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 10)
     assert(!lessLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
 
-    valueFromQueryValueArray(0) = Bytes.toBytes(5)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 5)
     assert(!lessLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
 
     //less than and equal
-    valueFromQueryValueArray(0) = Bytes.toBytes(20)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 20)
     assert(lessAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
 
-    valueFromQueryValueArray(0) = Bytes.toBytes(20)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 20)
     assert(lessAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
 
-    valueFromQueryValueArray(0) = Bytes.toBytes(10)
+    valueFromQueryValueArray(0) = FilterOps.encode(IntegerType, 10)
     assert(lessAndEqualLogic.execute(columnToCurrentRowValueMap, valueFromQueryValueArray))
 
     //equal too
diff --git hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/PartitionFilterSuite.scala hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/PartitionFilterSuite.scala
new file mode 100644
index 0000000..48cff2d
--- /dev/null
+++ hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/PartitionFilterSuite.scala
@@ -0,0 +1,209 @@
+package org.apache.hadoop.hbase.spark
+
+import org.apache.hadoop.hbase.spark.datasources.HBaseSparkConf
+import org.apache.hadoop.hbase.{TableName, HBaseTestingUtility}
+import org.apache.spark.sql.datasources.hbase.HBaseTableCatalog
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.{SparkConf, SparkContext, Logging}
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
+
+case class FilterRangeRecord(
+    col0: Integer,
+    col1: Boolean,
+    col2: Double,
+    col3: Float,
+    col4: Int,
+    col5: Long,
+    col6: Short,
+    col7: String,
+    col8: Byte)
+
+object FilterRangeRecord {
+  def apply(i: Int): FilterRangeRecord = {
+    FilterRangeRecord(if (i % 2 == 0) i else -i,
+      i % 2 == 0,
+      if (i % 2 == 0) i.toDouble else -i.toDouble,
+      i.toFloat,
+      if (i % 2 == 0) i else -i,
+      i.toLong,
+      i.toShort,
+      s"String$i extra",
+      i.toByte)
+  }
+}
+
+class PartitionFilterSuite extends FunSuite with
+BeforeAndAfterEach with BeforeAndAfterAll with Logging {
+  @transient var sc: SparkContext = null
+  var TEST_UTIL: HBaseTestingUtility = new HBaseTestingUtility
+
+  var sqlContext: SQLContext = null
+  var df: DataFrame = null
+
+  def withCatalog(cat: String): DataFrame = {
+    sqlContext
+      .read
+      .options(Map(HBaseTableCatalog.tableCatalog -> cat))
+      .format("org.apache.hadoop.hbase.spark")
+      .load()
+  }
+
+  override def beforeAll() {
+
+    TEST_UTIL.startMiniCluster
+    val sparkConf = new SparkConf
+    sparkConf.set(HBaseSparkConf.BLOCK_CACHE_ENABLE, "true")
+    sparkConf.set(HBaseSparkConf.BATCH_NUM, "100")
+    sparkConf.set(HBaseSparkConf.CACHE_SIZE, "100")
+
+    sc = new SparkContext("local", "test", sparkConf)
+    new HBaseContext(sc, TEST_UTIL.getConfiguration)
+    sqlContext = new SQLContext(sc)
+  }
+
+  override def afterAll() {
+    logInfo("shutting down minicluster")
+    TEST_UTIL.shutdownMiniCluster()
+
+    sc.stop()
+  }
+
+  override def beforeEach(): Unit = {
+    DefaultSourceStaticUtils.lastFiveExecutionRules.clear()
+  }
+
+  val catalog = s"""{
+                    |"table":{"namespace":"default", "name":"rangeTable"},
+                    |"rowkey":"key",
+                    |"columns":{
+                    |"col0":{"cf":"rowkey", "col":"key", "type":"int"},
+                    |"col1":{"cf":"cf1", "col":"col1", "type":"boolean"},
+                    |"col2":{"cf":"cf2", "col":"col2", "type":"double"},
+                    |"col3":{"cf":"cf3", "col":"col3", "type":"float"},
+                    |"col4":{"cf":"cf4", "col":"col4", "type":"int"},
+                    |"col5":{"cf":"cf5", "col":"col5", "type":"bigint"},
+                    |"col6":{"cf":"cf6", "col":"col6", "type":"smallint"},
+                    |"col7":{"cf":"cf7", "col":"col7", "type":"string"},
+                    |"col8":{"cf":"cf8", "col":"col8", "type":"tinyint"}
+                    |}
+                    |}""".stripMargin
+
+  test("populate rangeTable") {
+    val sql = sqlContext
+    import sql.implicits._
+
+    val data = (0 until 32).map { i =>
+      FilterRangeRecord(i)
+    }
+    sc.parallelize(data).toDF.write.options(
+      Map(HBaseTableCatalog.tableCatalog -> catalog, HBaseTableCatalog.newTable -> "5"))
+      .format("org.apache.hadoop.hbase.spark")
+      .save()
+  }
+
+
+  test("rangeTable full query") {
+    val df = withCatalog(catalog)
+    df.show
+    assert(df.count() === 32)
+  }
+
+  test("rangeTable rowkey less than 0") {
+    val sql = sqlContext
+    import sql.implicits._
+    val df = withCatalog(catalog)
+    val s = df.filter($"col0" < 0)
+    s.show
+    assert(s.count() === 16)
+  }
+
+  test("rangeTable int col less than 0") {
+    val sql = sqlContext
+    import sql.implicits._
+    val df = withCatalog(catalog)
+    val s = df.filter($"col4" < 0)
+    s.show
+    assert(s.count() === 16)
+  }
+
+  test("rangeTable double col less than 3.0") {
+    val sql = sqlContext
+    import sql.implicits._
+    val df = withCatalog(catalog)
+    val s = df.filter($"col2" < 3.0)
+    s.show
+    assert(s.count() === 18)
+  }
+
+  test("rangeTable lessequal than -10") {
+    val sql = sqlContext
+    import sql.implicits._
+    val df = withCatalog(catalog)
+    val s = df.filter($"col0" <= -10)
+    s.show
+    assert(s.count() === 11)
+  }
+
+  test("rangeTable lessequal than -9") {
+    val sql = sqlContext
+    import sql.implicits._
+    val df = withCatalog(catalog)
+    val s = df.filter($"col0" <= -9)
+    s.show
+    assert(s.count() === 12)
+  }
+
+  test("rangeTable greaterequal than -9") {
+    val sql = sqlContext
+    import sql.implicits._
+    val df = withCatalog(catalog)
+    val s = df.filter($"col0" >= -9)
+    s.show
+    assert(s.count() === 21)
+  }
+
+  test("rangeTable greaterequal than 0") {
+    val sql = sqlContext
+    import sql.implicits._
+    val df = withCatalog(catalog)
+    val s = df.filter($"col0" >= 0)
+    s.show
+    assert(s.count() === 16)
+  }
+
+  test("rangeTable greater than 10") {
+    val sql = sqlContext
+    import sql.implicits._
+    val df = withCatalog(catalog)
+    val s = df.filter($"col0" > 10)
+    s.show
+    assert(s.count() === 10)
+  }
+
+  test("rangeTable and") {
+    val sql = sqlContext
+    import sql.implicits._
+    val df = withCatalog(catalog)
+    val s = df.filter($"col0" > -10 && $"col0" <= 10)
+    s.show
+    assert(s.count() === 11)
+  }
+
+  test("or") {
+    val sql = sqlContext
+    import sql.implicits._
+    val df = withCatalog(catalog)
+    val s = df.filter($"col0" <= -10 || $"col0" > 10)
+    s.show
+    assert(s.count() === 21)
+  }
+
+  test("rangeTable all") {
+    val sql = sqlContext
+    import sql.implicits._
+    val df = withCatalog(catalog)
+    val s = df.filter($"col0" >= -100)
+    s.show
+    assert(s.count() === 32)
+  }
+}
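Not part of the patch: a minimal sketch of the range splitting that the new BoundRange.apply performs for a signed integer row key, so the reviewer can see why a single predicate can now yield several scan ranges. It assumes the patched hbase-spark module is on the classpath; the object name BoundRangeDemo is illustrative only.

import org.apache.hadoop.hbase.spark.datasources.BoundRange
import org.apache.hadoop.hbase.util.Bytes

object BoundRangeDemo {
  def main(args: Array[String]): Unit = {
    // For a negative Int such as -5, the two's-complement byte order does not match
    // numeric order, so "col0 < -5" cannot be one HBase scan range. BoundRange(-5)
    // reports one "less" range [Int.MinValue, -5] and two "greater" ranges
    // [-5, -1] and [0, Int.MaxValue], which DefaultSource merges via mergeUnion.
    BoundRange(-5).foreach { b =>
      b.less.foreach(r =>
        println(s"less:    ${Bytes.toInt(r.low)} .. ${Bytes.toInt(r.upper)}"))
      b.greater.foreach(r =>
        println(s"greater: ${Bytes.toInt(r.low)} .. ${Bytes.toInt(r.upper)}"))
    }
  }
}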
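Likewise not part of the patch: a small sketch of the type-aware comparison path (FilterOps.encode plus FilterOps.filter above) that replaces the raw Bytes.compareTo calls in the comparison expressions. The FilterOps calls and IntegerType come from the diff; FilterOpsDemo and the sample values are illustrative.

import org.apache.hadoop.hbase.spark.FilterOps
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.sql.types.IntegerType

object FilterOpsDemo {
  def main(args: Array[String]): Unit = {
    // Row-side value as HBase stores it: plain big-endian bytes, no type tag.
    val rowValue = Bytes.toBytes(10)
    // Query-side value as the planner now encodes it: a leading type byte (IntEnc)
    // followed by the big-endian payload.
    val queryValue = FilterOps.encode(IntegerType, 15)

    // 10 < 15, so a Less comparison should pass; filter() decodes both sides as Int
    // instead of comparing raw byte arrays of possibly different encodings.
    val matches = FilterOps.filter(
      rowValue, 0, rowValue.length,
      queryValue, 0, queryValue.length,
      FilterOps.Less)
    println(s"10 < 15 ? $matches") // expected: true
  }
}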