diff --git a/hbase-spark/pom.xml b/hbase-spark/pom.xml index e48f9e8..7417127 100644 --- a/hbase-spark/pom.xml +++ b/hbase-spark/pom.xml @@ -79,6 +79,12 @@ org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark spark-streaming_${scala.binary.version} ${spark.version} diff --git a/hbase-spark/src/main/java/org/apache/hadoop/hbase/spark/SparkSQLPushDownFilter.java b/hbase-spark/src/main/java/org/apache/hadoop/hbase/spark/SparkSQLPushDownFilter.java new file mode 100644 index 0000000..01e73b9 --- /dev/null +++ b/hbase-spark/src/main/java/org/apache/hadoop/hbase/spark/SparkSQLPushDownFilter.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hbase.spark; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.hbase.Cell; +import org.apache.hadoop.hbase.exceptions.DeserializationException; +import org.apache.hadoop.hbase.filter.Filter; +import org.apache.hadoop.hbase.filter.FilterBase; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectInputStream; +import java.io.ObjectOutput; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.util.HashMap; + +/** + * This filter will push down all qualifier logic given to us + * by SparkSQL so that we have make the filters at the region server level + * and avoid sending the data back to the client to be filtered. + */ +public class SparkSQLPushDownFilter extends FilterBase implements Serializable{ + protected static final Log log = LogFactory.getLog(SparkSQLPushDownFilter.class); + + HashMap columnFamilyQualifierFilterMap; + + private Filter filter; + + public SparkSQLPushDownFilter(Filter filter) { + this.filter = filter; + } + + public Filter getFilter() { + return filter; + } + + public SparkSQLPushDownFilter(HashMap columnFamilyQualifierFilterMap) { + this.columnFamilyQualifierFilterMap = columnFamilyQualifierFilterMap; + } + + /** + * This method will find the related filter logic for the given + * column family and qualifier then execute it. 
It will also + * not clone the in coming cell to avoid extra object creation + * + * @param c The Cell to be validated + * @return ReturnCode object to determine if skipping is required + * @throws IOException + */ + @Override + public ReturnCode filterKeyValue(Cell c) throws IOException { + + //Get filter if one exist + ColumnFilter filter = + columnFamilyQualifierFilterMap.get(new ColumnFamilyQualifierMapKeyWrapper( + c.getFamilyArray(), + c.getFamilyOffset(), + c.getFamilyLength(), + c.getQualifierArray(), + c.getQualifierOffset(), + c.getQualifierLength())); + + if (filter == null) { + //If no filter then just include values + return ReturnCode.INCLUDE; + } else { + //If there is a filter then run validation + if (filter.validate(c.getValueArray(), c.getValueOffset(), c.getValueLength())) { + return ReturnCode.INCLUDE; + } else { + //If validation fails then skill whole row + return ReturnCode.NEXT_ROW; + } + } + } + + /** + * Used to construct the object from a byte array on the Region Server side + * + * @param bytes object represented as a byte array + * @return a new PushDownFilter object + * @throws DeserializationException + */ + public static SparkSQLPushDownFilter parseFrom(final byte [] bytes) + throws DeserializationException { + + SparkSQLPushDownFilter result; + + ByteArrayInputStream bis = new ByteArrayInputStream(bytes); + ObjectInput in = null; + try { + in = new ObjectInputStream(bis); + result = (SparkSQLPushDownFilter)in.readObject(); + } catch (Exception e) { + throw new DeserializationException(e); + } finally { + try { + bis.close(); + } catch (IOException ex) { + log.error("Error while trying to parseFrom method: ", ex); + } + try { + if (in != null) { + in.close(); + } + } catch (IOException ex) { + log.error("Error while trying to parseFrom method: ", ex); + } + } + return result; + } + + /** + * @return The filter serialized using pb + */ + public byte [] toByteArray() throws IOException{ + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutput out = null; + try { + out = new ObjectOutputStream(bos); + out.writeObject(this); + return bos.toByteArray(); + } finally { + try { + if (out != null) { + out.close(); + } + } catch (IOException ex) { + // ignore close exception + } + try { + bos.close(); + } catch (IOException ex) { + // ignore close exception + } + } + } + +} diff --git a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/ColumnFamilyQualifierMapKeyWrapper.scala b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/ColumnFamilyQualifierMapKeyWrapper.scala new file mode 100644 index 0000000..ea6675a --- /dev/null +++ b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/ColumnFamilyQualifierMapKeyWrapper.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
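// A minimal sketch (not part of this patch) of how SparkSQLPushDownFilter is meant
// to be used: a map from family/qualifier to ColumnFilter (both added later in this
// patch) is wrapped in the filter and attached to a Scan, so filterKeyValue runs on
// the region server, with HBase moving the filter over the wire via toByteArray and
// parseFrom. The column family "c", qualifier "b" and point value "8" are illustrative.
import java.util
import org.apache.hadoop.hbase.client.Scan
import org.apache.hadoop.hbase.spark.{ColumnFamilyQualifierMapKeyWrapper, ColumnFilter, SparkSQLPushDownFilter}
import org.apache.hadoop.hbase.util.Bytes

val family = Bytes.toBytes("c")
val qualifier = Bytes.toBytes("b")
val qualifierFilterMap = new util.HashMap[ColumnFamilyQualifierMapKeyWrapper, ColumnFilter]()
qualifierFilterMap.put(
  new ColumnFamilyQualifierMapKeyWrapper(family, 0, family.length, qualifier, 0, qualifier.length),
  new ColumnFilter(Bytes.toBytes("8")))   // only cells whose value is "8" pass for c:b

val pushDownScan = new Scan()
pushDownScan.setFilter(new SparkSQLPushDownFilter(qualifierFilterMap))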
+ */ + +package org.apache.hadoop.hbase.spark + +import org.apache.hadoop.hbase.util.Bytes + +/** + * A wrapper class that will allow both columnFamily and qualifier to + * be the key of a hashMap. Also allow for finding the value in a hashmap + * with out cloning the HBase value from the HBase Cell object + * @param columnFamily ColumnFamily byte array + * @param columnFamilyOffSet Offset of columnFamily value in the array + * @param columnFamilyLength Length of the columnFamily value in the columnFamily array + * @param qualifier Qualifier byte array + * @param qualifierOffSet Offset of qualifier value in the array + * @param qualifierLength Length of the qualifier value with in the array + */ +class ColumnFamilyQualifierMapKeyWrapper(val columnFamily:Array[Byte], + val columnFamilyOffSet:Int, + val columnFamilyLength:Int, + val qualifier:Array[Byte], + val qualifierOffSet:Int, + val qualifierLength:Int) + extends Serializable{ + + override def equals(other:Any): Boolean = { + val otherWrapper = other.asInstanceOf[ColumnFamilyQualifierMapKeyWrapper] + + Bytes.compareTo(columnFamily, + columnFamilyOffSet, + columnFamilyLength, + otherWrapper.columnFamily, + otherWrapper.columnFamilyOffSet, + otherWrapper.columnFamilyLength) == 0 && Bytes.compareTo(qualifier, + qualifierOffSet, + qualifierLength, + otherWrapper.qualifier, + otherWrapper.qualifierOffSet, + otherWrapper.qualifierLength) == 0 + } + + override def hashCode():Int = { + Bytes.hashCode(columnFamily, columnFamilyOffSet, columnFamilyLength) + + Bytes.hashCode(qualifier, qualifierOffSet, qualifierLength) + } +} diff --git a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala new file mode 100644 index 0000000..8d6fb4d --- /dev/null +++ b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala @@ -0,0 +1,926 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hbase.spark + +import java.sql.Timestamp +import java.util +import java.util.concurrent.ConcurrentLinkedQueue + +import org.apache.hadoop.hbase.client.{ConnectionFactory, Get, Result, Scan} +import org.apache.hadoop.hbase.util.Bytes +import org.apache.hadoop.hbase.{TableName, HBaseConfiguration} +import org.apache.spark.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ + +import scala.collection.mutable + +/** + * DefaultSource for integration with Spark's dataframe datasources. 
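// A minimal sketch (not part of this patch) of why the wrapper above exists: two keys
// built over different backing arrays are equal and hash identically as long as their
// family/qualifier slices match, so a Cell's backing array can be used for map lookups
// without copying bytes out of it. The byte contents are illustrative.
import org.apache.hadoop.hbase.spark.ColumnFamilyQualifierMapKeyWrapper
import org.apache.hadoop.hbase.util.Bytes

val backingArray = Bytes.toBytes("xxc:ayy")   // stand-in for a Cell's backing array
val sliceKey = new ColumnFamilyQualifierMapKeyWrapper(backingArray, 2, 1, backingArray, 4, 1)
val copiedKey = new ColumnFamilyQualifierMapKeyWrapper(Bytes.toBytes("c"), 0, 1, Bytes.toBytes("a"), 0, 1)

assert(sliceKey == copiedKey)                    // equals compares only the slices
assert(sliceKey.hashCode == copiedKey.hashCode)  // so HashMap lookups succeed across arrays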
+ * This class with produce a relationProvider based on input give to it from spark + * + * In all this DefaultSource support the following datasource functionality + * - Scan range pruning through filter push down logic based on rowKeys + * - Filter push down logic on columns that are not rowKey columns + * - Qualifier filtering based on columns used in the SparkSQL statement + * - Type conversions of basic SQL types + */ +class DefaultSource extends RelationProvider { + + val TABLE_KEY:String = "hbase.table" + val SCHEMA_COLUMNS_MAPPING_KEY:String = "hbase.columns.mapping" + val BATCHING_NUM_KEY:String = "hbase.batching.num" + val CACHING_NUM_KEY:String = "hbase.caching.num" + val HBASE_CONFIG_RESOURCES_LOCATIONS:String = "hbase.config.resources" + val USE_HBASE_CONTEXT:String = "hbase.use.hbase.context" + + /** + * Is given input from SparkSQL to construct a BaseRelation + * @param sqlContext SparkSQL context + * @param parameters Parameters given to us from SparkSQL + * @return A BaseRelation Object + */ + override def createRelation(sqlContext: SQLContext, + parameters: Map[String, String]): + BaseRelation = { + + + val tableName = parameters.get(TABLE_KEY) + if (tableName.isEmpty) + new Throwable("Invalid value for " + TABLE_KEY +" '" + tableName + "'") + + val schemaMappingString = parameters.getOrElse(SCHEMA_COLUMNS_MAPPING_KEY, "") + val batchingNumStr = parameters.getOrElse(BATCHING_NUM_KEY, "1000") + val cachingNumStr = parameters.getOrElse(CACHING_NUM_KEY, "1000") + val hbaseConfigResources = parameters.getOrElse(HBASE_CONFIG_RESOURCES_LOCATIONS, "") + val useHBaseReources = parameters.getOrElse(USE_HBASE_CONTEXT, "true") + + val batchingNum:Int = try { + batchingNumStr.toInt + } catch { + case e:Exception => throw + new Throwable("Invalid value for " + BATCHING_NUM_KEY +" '" + batchingNumStr + "'", e ) + } + + val cachingNum:Int = try { + cachingNumStr.toInt + } catch { + case e:Exception => throw + new Throwable("Invalid value for " + CACHING_NUM_KEY +" '" + cachingNumStr + "'", e ) + } + + new HBaseRelation(tableName.get, + generateSchemaMappingMap(schemaMappingString), + batchingNum.toInt, + cachingNum.toInt, + hbaseConfigResources, + useHBaseReources.equalsIgnoreCase("true"))(sqlContext) + } + + /** + * Reads the SCHEMA_COLUMNS_MAPPING_KEY and converts it to a map of + * SchemaQualifierDefinitions with the original sql column name as the key + * @param schemaMappingString The schema mapping string from the SparkSQL map + * @return A map of definitions keyed by the SparkSQL column name + */ + def generateSchemaMappingMap(schemaMappingString:String): + java.util.HashMap[String, SchemaQualifierDefinition] = { + try { + val columnDefinitions = schemaMappingString.split(',') + val resultingMap = new java.util.HashMap[String, SchemaQualifierDefinition]() + columnDefinitions.map(cd => { + val parts = cd.trim.split(' ') + + //Make sure we get three parts + // + if (parts.length == 3) { + val hbaseDefinitionParts = if (parts(2).charAt(0) == ':') { + Array[String]("", "key") + } else { + parts(2).split(':') + } + resultingMap.put(parts(0), new SchemaQualifierDefinition(parts(0), + parts(1), hbaseDefinitionParts(0), hbaseDefinitionParts(1))) + } else { + throw new Throwable("Invalid value for schema mapping '" + cd + + "' should be ' :' " + + "for columns and ' :' for rowKeys") + } + }) + resultingMap + } catch { + case e:Exception => throw + new Throwable("Invalid value for " + SCHEMA_COLUMNS_MAPPING_KEY + + " '" + schemaMappingString + "'", e ) + } + } +} + +/** + * Implementation 
of Spark BaseRelation that will build up our scan logic + * , do the scan pruning, filter push down, and value conversions + * + * @param tableName HBase table that we plan to read from + * @param schemaMappingDefinition SchemaMapping information to map HBase + * Qualifiers to SparkSQL columns + * @param batchingNum The batching number to be applied to the + * scan object + * @param cachingNum The caching number to be applied to the + * scan object + * @param configResources Optional comma separated list of config resources + * to get based on their URI + * @param useHBaseContext If true this will look to see if + * HBaseContext.latest is populated to use that + * connection information + * @param sqlContext SparkSQL context + */ +class HBaseRelation (val tableName:String, + val schemaMappingDefinition:java.util.HashMap[String, SchemaQualifierDefinition], + val batchingNum:Int, + val cachingNum:Int, + val configResources:String, + val useHBaseContext:Boolean) ( + @transient val sqlContext:SQLContext) + extends BaseRelation with PrunedFilteredScan with Logging { + + //create or get latest HBaseContext + @transient val hbaseContext:HBaseContext = if (useHBaseContext) { + LatestHBaseContextCache.latest + } else { + val config = HBaseConfiguration.create() + configResources.split(",").foreach( r => config.addResource(r)) + new HBaseContext(sqlContext.sparkContext, config) + } + + /** + * Generates a Spark SQL schema object so Spark SQL knows what is being + * provided by this BaseRelation + * + * @return schema generated from the SCHEMA_COLUMNS_MAPPING_KEY value + */ + override def schema: StructType = { + + val metadataBuilder = new MetadataBuilder() + + val structFieldArray = new Array[StructField](schemaMappingDefinition.size()) + + val schemaMappingDefinitionIt = schemaMappingDefinition.values().iterator() + var indexCounter = 0 + while (schemaMappingDefinitionIt.hasNext) { + val c = schemaMappingDefinitionIt.next() + + val metadata = metadataBuilder.putString("name", c.columnName).build() + val struckField = + new StructField(c.columnName, c.columnSparkSqlType, nullable = true, metadata) + + structFieldArray(indexCounter) = struckField + indexCounter += 1 + } + + val result = new StructType(structFieldArray) + //TODO push schema to listener + result + } + + /** + * Here we are building the functionality to populate the resulting RDD[Row] + * Here is where we will do the following: + * - Filter push down + * - Scan or GetList pruning + * - Executing our scan(s) or/and GetList to generate result + * + * @param requiredColumns The columns that are being requested by the requesting query + * @param filters The filters that are being applied by the requesting query + * @return RDD will all the results from HBase needed for SparkSQL to + * execute the query on + */ + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { + + val columnFilterCollection = buildColumnFilterCollection(filters) + + val requiredQualifierDefinitionArray = new mutable.MutableList[SchemaQualifierDefinition] + requiredColumns.foreach( c => { + val definition = schemaMappingDefinition.get(c) + if (definition.columnFamilyBytes.length > 0) { + requiredQualifierDefinitionArray += definition + } + }) + + //Create a local variable so that scala doesn't have to + // serialize the whole HBaseRelation Object + val serializableDefinitionMap = schemaMappingDefinition + + + //retain the information for unit testing checks + 
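// A minimal sketch (not part of this patch) of the hbase.columns.mapping string that
// feeds schemaMappingDefinition and therefore the schema built above. Entries take the form
//   "<sparkColumn> <sqlType> <family>:<qualifier>"  for normal columns, and
//   "<sparkColumn> <sqlType> :key"                  for the row key column.
// The column names below are illustrative.
import org.apache.hadoop.hbase.spark.DefaultSource

val mapping = new DefaultSource().generateSchemaMappingMap(
  "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b")

assert(mapping.get("KEY_FIELD").columnFamily.isEmpty)   // empty family marks the row key
assert(mapping.get("A_FIELD").columnFamily == "c")
assert(mapping.get("A_FIELD").qualifier == "a")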
DefaultSourceStaticUtils.populateLatestExecutionRules(columnFilterCollection, + requiredQualifierDefinitionArray) + + var resultRDD: RDD[Row] = null + + if (columnFilterCollection != null) { + val pushDownFilterJava = + new SparkSQLPushDownFilter( + columnFilterCollection.generateFamilyQualifiterFilterMap(schemaMappingDefinition)) + + val getList = new util.ArrayList[Get]() + val rddList = new util.ArrayList[RDD[Row]]() + + val it = columnFilterCollection.columnFilterMap.iterator + while (it.hasNext) { + val e = it.next() + val columnDefinition = schemaMappingDefinition.get(e._1) + //check is a rowKey + if (columnDefinition != null && columnDefinition.columnFamily.isEmpty) { + //add points to getList + e._2.points.foreach(p => { + val get = new Get(p) + requiredQualifierDefinitionArray.foreach( d => get.addColumn(d.columnFamilyBytes, d.qualifierBytes)) + getList.add(get) + }) + + val rangeIt = e._2.ranges.iterator + + while (rangeIt.hasNext) { + val r = rangeIt.next() + + val scan = new Scan() + scan.setBatch(batchingNum) + scan.setCaching(cachingNum) + requiredQualifierDefinitionArray.foreach( d => scan.addColumn(d.columnFamilyBytes, d.qualifierBytes)) + + if (pushDownFilterJava.columnFamilyQualifierFilterMap.size() > 0) { + scan.setFilter(pushDownFilterJava) + } + + //Check if there is a lower bound + if (r.lowerBound != null && r.lowerBound.length > 0) { + + if (r.isLowerBoundEqualTo) { + //HBase startRow is inclusive: Therefore it acts like isLowerBoundEqualTo + // by default + scan.setStartRow(r.lowerBound) + } else { + //Since we don't equalTo we want the next value we need + // to add another byte to the start key. That new byte will be + // the min byte value. + val newArray = new Array[Byte](r.lowerBound.length + 1) + System.arraycopy(r.lowerBound, 0, newArray, 0, r.lowerBound.length) + + //new Min Byte + newArray(r.lowerBound.length) = Byte.MinValue + scan.setStartRow(newArray) + } + } + + //Check if there is a upperBound + if (r.upperBound != null && r.upperBound.length > 0) { + if (r.isUpperBoundEqualTo) { + //HBase stopRow is exclusive: therefore it DOESN'T ast like isUpperBoundEqualTo + // by default. So we need to add a new max byte to the stopRow key + val newArray = new Array[Byte](r.upperBound.length + 1) + System.arraycopy(r.upperBound, 0, newArray, 0, r.upperBound.length) + + //New Max Bytes + newArray(r.upperBound.length) = Byte.MaxValue + + scan.setStopRow(newArray) + } else { + //Here equalTo is false for Upper bound which is exclusive and + // HBase stopRow acts like that by default so no need to mutate the + // rowKey + scan.setStopRow(r.upperBound) + } + } + + val rdd = hbaseContext.hbaseRDD(TableName.valueOf(tableName), scan).map(r => { + Row.fromSeq(requiredColumns.map(c => + DefaultSourceStaticUtils.getValue(c, serializableDefinitionMap, r._2))) + }) + rddList.add(rdd) + } + } + } + + //If there is more then one RDD then we have to union them together + for (i <- 0 until rddList.size()) { + if (resultRDD == null) resultRDD = rddList.get(i) + else resultRDD = resultRDD.union(rddList.get(i)) + + } + + //If there are gets then we can get them from the driver and union that rdd in + // with the rest of the values. 
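// A minimal sketch (not part of this patch) of driving the scan-building path above
// from user code, assuming an HBaseContext has already been created for this
// SparkContext (so hbase.use.hbase.context can stay at its default of "true") and
// that a table "t1" with column family "c" exists. Option keys come from DefaultSource.
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)   // sc: an existing SparkContext
val df = sqlContext.load(
  "org.apache.hadoop.hbase.spark",
  Map("hbase.table" -> "t1",
    "hbase.columns.mapping" -> "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b",
    "hbase.batching.num" -> "100",    // optional, defaults to 1000
    "hbase.caching.num" -> "100"))    // optional, defaults to 1000
df.registerTempTable("hbaseTmp")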
+ if (getList.size() > 0) { + val connection = ConnectionFactory.createConnection(hbaseContext.tmpHdfsConfiguration) + val table = connection.getTable(TableName.valueOf(tableName)) + try { + val results = table.get(getList) + val rowList = mutable.MutableList[Row]() + for (i <- 0 until results.length) { + val rowArray = requiredColumns.map(c => + DefaultSourceStaticUtils.getValue(c, schemaMappingDefinition, results(i))) + rowList += Row.fromSeq(rowArray) + } + val getRDD = sqlContext.sparkContext.parallelize(rowList) + if (resultRDD == null) resultRDD = getRDD + else { + resultRDD = resultRDD.union(getRDD) + } + } finally { + try { + connection.close() + } finally { + table.close() + } + } + + } + } + if (resultRDD == null) { + val scan = new Scan() + scan.setBatch(batchingNum) + scan.setCaching(cachingNum) + requiredQualifierDefinitionArray.foreach( d => scan.addColumn(d.columnFamilyBytes, d.qualifierBytes)) + + val rdd = hbaseContext.hbaseRDD(TableName.valueOf(tableName), scan).map(r => { + Row.fromSeq(requiredColumns.map(c => DefaultSourceStaticUtils.getValue(c, + serializableDefinitionMap, r._2))) + }) + resultRDD=rdd + } + resultRDD + } + + /** + * Root recursive function that will loop over the filters provided by + * SparkSQL. Some filters are AND or OR functions and contain additional filters + * hence the need for recursion. + * + * @param filters Filters provided by SparkSQL. + * Filters are joined with the AND operater + * @return A ColumnFilterCollection whish is a consolidated construct to + * hold the high level filter information + */ + def buildColumnFilterCollection(filters: Array[Filter]): ColumnFilterCollection = { + var superCollection: ColumnFilterCollection = null + + filters.foreach( f => { + val parentCollection = new ColumnFilterCollection + buildColumnFilterCollection(parentCollection, f) + if (superCollection == null) superCollection = parentCollection + else {superCollection.mergeIntersect(parentCollection)} + }) + superCollection + } + + /** + * Recursive function that will work to convert Spark Filter objects to ColumnFilterCollection + * + * @param parentFilterCollection Parent ColumnFilterCollection + * @param filter Current given filter from SparkSQL + */ + def buildColumnFilterCollection(parentFilterCollection:ColumnFilterCollection, + filter:Filter): Unit = { + filter match { + + case EqualTo(attr, value) => + parentFilterCollection.mergeUnion(attr, + new ColumnFilter(DefaultSourceStaticUtils.getByteValue(attr, + schemaMappingDefinition, value.toString))) + + case LessThan(attr, value) => + parentFilterCollection.mergeUnion(attr, new ColumnFilter(null, + new ScanRange(DefaultSourceStaticUtils.getByteValue(attr, + schemaMappingDefinition, value.toString), false, + new Array[Byte](0), true))) + + case GreaterThan(attr, value) => + parentFilterCollection.mergeUnion(attr, new ColumnFilter(null, + new ScanRange(null, true, DefaultSourceStaticUtils.getByteValue(attr, + schemaMappingDefinition, value.toString), false))) + + case LessThanOrEqual(attr, value) => + parentFilterCollection.mergeUnion(attr, new ColumnFilter(null, + new ScanRange(DefaultSourceStaticUtils.getByteValue(attr, + schemaMappingDefinition, value.toString), true, + new Array[Byte](0), true))) + + case GreaterThanOrEqual(attr, value) => + parentFilterCollection.mergeUnion(attr, new ColumnFilter(null, + new ScanRange(null, true, DefaultSourceStaticUtils.getByteValue(attr, + schemaMappingDefinition, value.toString), true))) + + case Or(left, right) => + 
buildColumnFilterCollection(parentFilterCollection, left) + val rightSideCollection = new ColumnFilterCollection + buildColumnFilterCollection(rightSideCollection, right) + parentFilterCollection.mergeUnion(rightSideCollection) + case And(left, right) => + buildColumnFilterCollection(parentFilterCollection, left) + val rightSideCollection = new ColumnFilterCollection + buildColumnFilterCollection(rightSideCollection, right) + parentFilterCollection.mergeIntersect(rightSideCollection) + case _ => //nothing + } + } +} + +/** + * Construct to contains column data that spend SparkSQL and HBase + * + * @param columnName SparkSQL column name + * @param colType SparkSQL column type + * @param columnFamily HBase column family + * @param qualifier HBase qualifier name + */ +case class SchemaQualifierDefinition(columnName:String, + colType:String, + columnFamily:String, + qualifier:String) extends Serializable { + val columnFamilyBytes = Bytes.toBytes(columnFamily) + val qualifierBytes = Bytes.toBytes(qualifier) + val columnSparkSqlType = if (colType.equals("BOOLEAN")) BooleanType + else if (colType.equals("TINYINT")) IntegerType + else if (colType.equals("INT")) IntegerType + else if (colType.equals("BIGINT")) LongType + else if (colType.equals("FLOAT")) FloatType + else if (colType.equals("DOUBLE")) DoubleType + else if (colType.equals("STRING")) StringType + else if (colType.equals("TIMESTAMP")) TimestampType + else if (colType.equals("DECIMAL")) StringType //DataTypes.createDecimalType(precision, scale) + else throw new Throwable("Unsupported column type :" + colType) +} + +/** + * Construct to contain a single scan ranges information. Also + * provide functions to merge with other scan ranges through AND + * or OR operators + * + * @param upperBound Upper bound of scan + * @param isUpperBoundEqualTo Include upper bound value in the results + * @param lowerBound Lower bound of scan + * @param isLowerBoundEqualTo Include lower bound value in the results + */ +class ScanRange(var upperBound:Array[Byte], var isUpperBoundEqualTo:Boolean, + var lowerBound:Array[Byte], var isLowerBoundEqualTo:Boolean) + extends Serializable { + + /** + * Function to merge another scan object through a AND operation + * @param other Other scan object + */ + def mergeIntersect(other:ScanRange): Unit = { + val upperBoundCompare = compareRange(upperBound, other.upperBound) + val lowerBoundCompare = compareRange(lowerBound, other.lowerBound) + + upperBound = if (upperBoundCompare <0) upperBound else other.upperBound + lowerBound = if (lowerBoundCompare >0) lowerBound else other.lowerBound + + isLowerBoundEqualTo = if (lowerBoundCompare == 0) + isLowerBoundEqualTo && other.isLowerBoundEqualTo + else isLowerBoundEqualTo + + isUpperBoundEqualTo = if (upperBoundCompare == 0) + isUpperBoundEqualTo && other.isUpperBoundEqualTo + else isUpperBoundEqualTo + } + + /** + * Function to merge another scan object through a OR operation + * @param other Other scan object + */ + def mergeUnion(other:ScanRange): Unit = { + + val upperBoundCompare = compareRange(upperBound, other.upperBound) + val lowerBoundCompare = compareRange(lowerBound, other.lowerBound) + + upperBound = if (upperBoundCompare >0) upperBound else other.upperBound + lowerBound = if (lowerBoundCompare <0) lowerBound else other.lowerBound + + isLowerBoundEqualTo = if (lowerBoundCompare == 0) + isLowerBoundEqualTo || other.isLowerBoundEqualTo + else isLowerBoundEqualTo + + isUpperBoundEqualTo = if (upperBoundCompare == 0) + isUpperBoundEqualTo || other.isUpperBoundEqualTo 
+ else isUpperBoundEqualTo + } + + /** + * Common function to see if this scan over laps with another + * + * Reference Visual + * + * A B + * |---------------------------| + * LL--------------LU + * RL--------------RU + * + * A = lowest value is byte[0] + * B = highest value is null + * LL = Left Lower Bound + * LU = Left Upper Bound + * RL = Right Lower Bound + * RU = Right Upper Bound + * + * @param other Other scan object + * @return True is overlap false is not overlap + */ + def doesOverLap(other:ScanRange): Boolean = { + + var leftRange:ScanRange = null + var rightRange:ScanRange = null + + //First identify the Left range + // Also lower bound can't be null + if (Bytes.compareTo(lowerBound, other.lowerBound) <=0) { + leftRange = this + rightRange = other + } else { + leftRange = other + rightRange = this + } + + //Then see if leftRange goes to null or if leftRange.upperBound + // upper is greater or equals to rightRange.lowerBound + leftRange.upperBound == null || + Bytes.compareTo(leftRange.upperBound, rightRange.lowerBound) >= 0 + } + + /** + * Special compare logic because we can have null values + * for left or right bound + * + * @param left Left byte array + * @param right Right byte array + * @return 0 for equals 1 is left is greater and -1 is right is greater + */ + def compareRange(left:Array[Byte], right:Array[Byte]): Int = { + if (left == null && right == null) 0 + else if (left == null && right != null) 1 + else if (left != null && right == null) -1 + else Bytes.compareTo(left, right) + } + + override def toString:String = { + "ScanRange:(" + Bytes.toString(upperBound) + "," + isUpperBoundEqualTo + "," + + Bytes.toString(lowerBound) + "," + isLowerBoundEqualTo + ")" + } +} + +/** + * Contains information related to a filters for a given column. + * This can contain many ranges or points. 
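// A minimal sketch (not part of this patch) of the range merging above, using string
// row keys for readability: an OR of [get1, get3] and (get3, get5] overlaps, so the
// union collapses into the single range [get1, get5].
import org.apache.hadoop.hbase.spark.ScanRange
import org.apache.hadoop.hbase.util.Bytes

// Constructor order is (upperBound, isUpperBoundEqualTo, lowerBound, isLowerBoundEqualTo)
val leftRange = new ScanRange(Bytes.toBytes("get3"), true, Bytes.toBytes("get1"), true)
val rightRange = new ScanRange(Bytes.toBytes("get5"), true, Bytes.toBytes("get3"), false)

if (leftRange.doesOverLap(rightRange)) leftRange.mergeUnion(rightRange)
assert(leftRange.toString == "ScanRange:(get5,true,get1,true)")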
+ * + * @param currentPoint the initial point when the filter is created + * @param currentRange the initial scanRange when the filter is created + */ +class ColumnFilter (currentPoint:Array[Byte] = null, + currentRange:ScanRange = null) extends Serializable { + //Collection of ranges + var ranges = new mutable.MutableList[ScanRange]() + if (currentRange != null ) ranges.+=(currentRange) + + //Collection of points + var points = new mutable.MutableList[Array[Byte]]() + if (currentPoint != null) { + points.+=(currentPoint) + } + + /** + * This will validate a give value through the filter's points and/or ranges + * the result will be if the value passed the filter + * + * @param value Value to be validated + * @param valueOffSet The offset of the value + * @param valueLength The length of the value + * @return True is the value passes the filter false if not + */ + def validate(value:Array[Byte], valueOffSet:Int, valueLength:Int):Boolean = { + var result = false + + points.foreach( p => { + if (Bytes.equals(p, 0, p.length, value, valueOffSet, valueLength)) { + result = true + } + }) + + ranges.foreach( r => { + val upperBoundPass = r.upperBound == null || + (r.isUpperBoundEqualTo && + Bytes.compareTo(r.upperBound, 0, r.upperBound.length, + value, valueOffSet, valueLength) >= 0) || + (!r.isUpperBoundEqualTo && + Bytes.compareTo(r.upperBound, 0, r.upperBound.length, + value, valueOffSet, valueLength) > 0) + + val lowerBoundPass = r.lowerBound == null || r.lowerBound.length == 0 + (r.isLowerBoundEqualTo && + Bytes.compareTo(r.lowerBound, 0, r.lowerBound.length, + value, valueOffSet, valueLength) <= 0) || + (!r.isLowerBoundEqualTo && + Bytes.compareTo(r.lowerBound, 0, r.lowerBound.length, + value, valueOffSet, valueLength) < 0) + + result = result || (upperBoundPass && lowerBoundPass) + }) + result + } + + /** + * This will allow us to merge filter logic that is joined to the existing filter + * through a OR operator + * + * @param other Filter to merge + */ + def mergeUnion(other:ColumnFilter): Unit = { + other.points.foreach( p => points += p) + + other.ranges.foreach( otherR => { + var doesOverLap = false + ranges.foreach{ r => + if (r.doesOverLap(otherR)) { + r.mergeUnion(otherR) + doesOverLap = true + }} + if (!doesOverLap) ranges.+=(otherR) + }) + } + + /** + * This will allow us to merge filter logic that is joined to the existing filter + * through a AND operator + * + * @param other Filter to merge + */ + def mergeIntersect(other:ColumnFilter): Unit = { + val survivingPoints = new mutable.MutableList[Array[Byte]]() + points.foreach( p => { + other.points.foreach( otherP => { + if (Bytes.equals(p, otherP)) { + survivingPoints.+=(p) + } + }) + }) + points = survivingPoints + + val survivingRanges = new mutable.MutableList[ScanRange]() + + other.ranges.foreach( otherR => { + ranges.foreach( r => { + if (r.doesOverLap(otherR)) { + r.mergeIntersect(otherR) + survivingRanges += r + } + }) + }) + ranges = survivingRanges + } + + override def toString:String = { + val strBuilder = new StringBuilder + strBuilder.append("(points:(") + var isFirst = true + points.foreach( p => { + if (isFirst) isFirst = false + else strBuilder.append(",") + strBuilder.append(Bytes.toString(p)) + }) + strBuilder.append("),ranges:") + isFirst = true + ranges.foreach( r => { + if (isFirst) isFirst = false + else strBuilder.append(",") + strBuilder.append(r) + }) + strBuilder.append("))") + strBuilder.toString() + } +} + +/** + * A collection of ColumnFilters indexed by column names. 
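// A minimal sketch (not part of this patch) of ColumnFilter.validate: a filter built
// from the point "8" plus the range strictly below "3" accepts "8" through its point
// list and rejects "5", which matches neither the point nor the range.
import org.apache.hadoop.hbase.spark.{ColumnFilter, ScanRange}
import org.apache.hadoop.hbase.util.Bytes

val cellFilter = new ColumnFilter(Bytes.toBytes("8"),
  new ScanRange(Bytes.toBytes("3"), false, new Array[Byte](0), true))

val hit = Bytes.toBytes("8")
val miss = Bytes.toBytes("5")
assert(cellFilter.validate(hit, 0, hit.length))
assert(!cellFilter.validate(miss, 0, miss.length))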
+ * + * Also contains merge commends that will consolidate the filters + * per column name + */ +class ColumnFilterCollection { + val columnFilterMap = new mutable.HashMap[String, ColumnFilter] + + def clear(): Unit = { + columnFilterMap.clear() + } + + /** + * This will allow us to merge filter logic that is joined to the existing filter + * through a OR operator. This will merge a single columns filter + * + * @param column The column to be merged + * @param other The other ColumnFilter object to merge + */ + def mergeUnion(column:String, other:ColumnFilter): Unit = { + val existingFilter = columnFilterMap.get(column) + if (existingFilter.isEmpty) { + columnFilterMap.+=((column, other)) + } else { + existingFilter.get.mergeUnion(other) + } + } + + /** + * This will allow us to merge all filters in the existing collection + * to the filters in the other collection. All merges are done as a result + * of a OR operator + * + * @param other The other Column Filter Collection to be merged + */ + def mergeUnion(other:ColumnFilterCollection): Unit = { + other.columnFilterMap.foreach( e => { + mergeUnion(e._1, e._2) + }) + } + + /** + * This will allow us to merge all filters in the existing collection + * to the filters in the other collection. All merges are done as a result + * of a AND operator + * + * @param other The column filter from the other collection + */ + def mergeIntersect(other:ColumnFilterCollection): Unit = { + other.columnFilterMap.foreach( e => { + val existingColumnFilter = columnFilterMap.get(e._1) + if (existingColumnFilter.isEmpty) { + columnFilterMap += e + } else { + existingColumnFilter.get.mergeIntersect(e._2) + } + }) + } + + /** + * This will collect all the filter information in a way that is optimized + * for the HBase filter commend. Allowing the filter to be accessed + * with columnFamily and qualifier information + * + * @param schemaDefinitionMap Schema Map that will help us map the right filters + * to the correct columns + * @return HashMap oc column filters + */ + def generateFamilyQualifiterFilterMap(schemaDefinitionMap: + java.util.HashMap[String, SchemaQualifierDefinition]): + util.HashMap[ColumnFamilyQualifierMapKeyWrapper, ColumnFilter] = { + val familyQualifierFilterMap = + new util.HashMap[ColumnFamilyQualifierMapKeyWrapper, ColumnFilter]() + + columnFilterMap.foreach( e => { + val definition = schemaDefinitionMap.get(e._1) + //Don't add rowKeyFilter + if (definition.columnFamilyBytes.size > 0) { + familyQualifierFilterMap.put( + new ColumnFamilyQualifierMapKeyWrapper( + definition.columnFamilyBytes, 0, definition.columnFamilyBytes.length, + definition.qualifierBytes, 0, definition.qualifierBytes.length), e._2) + } + }) + familyQualifierFilterMap + } + + override def toString:String = { + val strBuilder = new StringBuilder + columnFilterMap.foreach( e => strBuilder.append(e)) + strBuilder.toString() + } +} + +/** + * Status object to store static functions but also to hold last executed + * information that can be used for unit testing. + */ +object DefaultSourceStaticUtils { + + //This will contain the last 5 filters and required fields used in buildScan + // These values can be used in unit testing to make sure we are converting + // The Spark SQL input correctly + val lastFiveExecutionRules = + new ConcurrentLinkedQueue[ExecutionRuleForUnitTesting]() + + /** + * This method is to populate the lastFiveExecutionRules for unit test perposes + * This method is not thread safe. 
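// A minimal sketch (not part of this patch) of how the collection consolidates
// predicates per column: OR merges union a column's points, while an AND merge
// between collections keeps only what both sides accept. The column name and
// point values are illustrative.
import org.apache.hadoop.hbase.spark.{ColumnFilter, ColumnFilterCollection}
import org.apache.hadoop.hbase.util.Bytes

val orSide = new ColumnFilterCollection
orSide.mergeUnion("KEY_FIELD", new ColumnFilter(Bytes.toBytes("get1")))
orSide.mergeUnion("KEY_FIELD", new ColumnFilter(Bytes.toBytes("get2")))
// orSide now holds one KEY_FIELD filter with the points get1 and get2

val andSide = new ColumnFilterCollection
andSide.mergeUnion("KEY_FIELD", new ColumnFilter(Bytes.toBytes("get2")))
orSide.mergeIntersect(andSide)
assert(orSide.columnFilterMap("KEY_FIELD").points.length == 1)   // only get2 survives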
+ * + * @param columnFilterCollection The filters in the last job + * @param requiredQualifierDefinitionArray The required columns in the last job + */ + def populateLatestExecutionRules(columnFilterCollection: ColumnFilterCollection, + requiredQualifierDefinitionArray: mutable.MutableList[SchemaQualifierDefinition]):Unit = { + lastFiveExecutionRules.add(new ExecutionRuleForUnitTesting( + columnFilterCollection, requiredQualifierDefinitionArray)) + while (lastFiveExecutionRules.size() > 5) { + lastFiveExecutionRules.poll() + } + } + + /** + * This method will convert the result content from HBase into the + * SQL value type that is requested by the Spark SQL schema definition + * + * @param columnName The name of the SparkSQL Column + * @param schemaMappingDefinition The schema definition map + * @param r The result object from HBase + * @return The converted object type + */ + def getValue(columnName: String, + schemaMappingDefinition: java.util.HashMap[String, SchemaQualifierDefinition], + r: Result): Any = { + + val columnDef = schemaMappingDefinition.get(columnName) + + if (columnDef == null) throw new Throwable("Unknown column:" + columnName) + + + if (columnDef.columnFamilyBytes.isEmpty) { + val roKey = r.getRow + + columnDef.columnSparkSqlType match { + case IntegerType => Bytes.toInt(roKey) + case LongType => Bytes.toLong(roKey) + case FloatType => Bytes.toFloat(roKey) + case DoubleType => Bytes.toDouble(roKey) + case StringType => Bytes.toString(roKey) + case TimestampType => new Timestamp(Bytes.toLong(roKey)) + case _ => Bytes.toString(roKey) + } + } else { + val cellByteValue = + r.getColumnLatestCell(columnDef.columnFamilyBytes, columnDef.qualifierBytes) + if (cellByteValue == null) null + else columnDef.columnSparkSqlType match { + case IntegerType => Bytes.toInt(cellByteValue.getValueArray, + cellByteValue.getValueOffset, cellByteValue.getValueLength) + case LongType => Bytes.toLong(cellByteValue.getValueArray, + cellByteValue.getValueOffset, cellByteValue.getValueLength) + case FloatType => Bytes.toFloat(cellByteValue.getValueArray, + cellByteValue.getValueOffset) + case DoubleType => Bytes.toDouble(cellByteValue.getValueArray, + cellByteValue.getValueOffset) + case StringType => Bytes.toString(cellByteValue.getValueArray, + cellByteValue.getValueOffset, cellByteValue.getValueLength) + case TimestampType => new Timestamp(Bytes.toLong(cellByteValue.getValueArray, + cellByteValue.getValueOffset, cellByteValue.getValueLength)) + case _ => Bytes.toString(cellByteValue.getValueArray, + cellByteValue.getValueOffset, cellByteValue.getValueLength) + } + } + } + + /** + * This will convert the value from SparkSQL to be stored into HBase using the + * right byte Type + * + * @param columnName SparkSQL column name + * @param schemaMappingDefinition Schema definition map + * @param value String value from SparkSQL + * @return Returns the byte array to go into HBase + */ + def getByteValue(columnName: String, + schemaMappingDefinition: java.util.HashMap[String, SchemaQualifierDefinition], + value: String): Array[Byte] = { + + val columnDef = schemaMappingDefinition.get(columnName) + + if (columnDef == null) throw new Throwable("Unknown column:" + columnName) + else { + columnDef.columnSparkSqlType match { + case IntegerType => Bytes.toBytes(value.toInt) + case LongType => Bytes.toBytes(value.toLong) + case FloatType => Bytes.toBytes(value.toFloat) + case DoubleType => Bytes.toBytes(value.toDouble) + case StringType => Bytes.toBytes(value) + case TimestampType => 
Bytes.toBytes(value.toLong) + case _ => Bytes.toBytes(value) + } + } + } + + class ExecutionRuleForUnitTesting(val columnFilterCollection: ColumnFilterCollection, + val requiredQualifierDefinitionArray: + mutable.MutableList[SchemaQualifierDefinition]) +} \ No newline at end of file diff --git a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala index f060fea..ab4dbf5 100644 --- a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala +++ b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala @@ -64,6 +64,8 @@ class HBaseContext(@transient sc: SparkContext, val broadcastedConf = sc.broadcast(new SerializableWritable(config)) val credentialsConf = sc.broadcast(new SerializableWritable(job.getCredentials)) + LatestHBaseContextCache.latest = this + if (tmpHdfsConfgFile != null && config != null) { val fs = FileSystem.newInstance(config) val tmpPath = new Path(tmpHdfsConfgFile) @@ -568,3 +570,7 @@ class HBaseContext(@transient sc: SparkContext, private[spark] def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]] } + +object LatestHBaseContextCache { + var latest:HBaseContext = null +} diff --git a/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala b/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala new file mode 100644 index 0000000..67f7ff0 --- /dev/null +++ b/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
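// A minimal sketch (not part of this patch) of getByteValue: the string literal Spark
// SQL hands over is re-encoded into the byte representation HBase stores for the
// mapped column type. The mapping string below is illustrative.
import org.apache.hadoop.hbase.spark.{DefaultSource, DefaultSourceStaticUtils}
import org.apache.hadoop.hbase.util.Bytes

val typedMapping = new DefaultSource().generateSchemaMappingMap(
  "KEY_FIELD STRING :key, I_FIELD INT c:i")

val keyBytes = DefaultSourceStaticUtils.getByteValue("KEY_FIELD", typedMapping, "get1")
val intBytes = DefaultSourceStaticUtils.getByteValue("I_FIELD", typedMapping, "42")
assert(Bytes.toString(keyBytes) == "get1")
assert(Bytes.toInt(intBytes) == 42)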
+ */ + +package org.apache.hadoop.hbase.spark + +import org.apache.hadoop.hbase.client.{Put, ConnectionFactory} +import org.apache.hadoop.hbase.util.Bytes +import org.apache.hadoop.hbase.{TableName, HBaseTestingUtility} +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.{SparkContext, Logging} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} + +class DefaultSourceSuite extends FunSuite with +BeforeAndAfterEach with BeforeAndAfterAll with Logging { + @transient var sc: SparkContext = null + var TEST_UTIL: HBaseTestingUtility = new HBaseTestingUtility + + val tableName = "t1" + val columnFamily = "c" + + var sqlContext:SQLContext = null + var df:DataFrame = null + + override def beforeAll() { + + TEST_UTIL.startMiniCluster + + logInfo(" - minicluster started") + try + TEST_UTIL.deleteTable(TableName.valueOf(tableName)) + catch { + case e: Exception => logInfo(" - no table " + tableName + " found") + + } + logInfo(" - creating table " + tableName) + TEST_UTIL.createTable(TableName.valueOf(tableName), Bytes.toBytes(columnFamily)) + logInfo(" - created table") + + sc = new SparkContext("local", "test") + + val connection = ConnectionFactory.createConnection(TEST_UTIL.getConfiguration) + val table = connection.getTable(TableName.valueOf("t1")) + + try { + var put = new Put(Bytes.toBytes("get1")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("1")) + table.put(put) + put = new Put(Bytes.toBytes("get2")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("4")) + table.put(put) + put = new Put(Bytes.toBytes("get3")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("8")) + table.put(put) + put = new Put(Bytes.toBytes("get4")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo4")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("10")) + table.put(put) + put = new Put(Bytes.toBytes("get5")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo5")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("8")) + table.put(put) + } finally { + table.close() + connection.close() + } + + new HBaseContext(sc, TEST_UTIL.getConfiguration) + sqlContext = new SQLContext(sc) + + df = sqlContext.load("org.apache.hadoop.hbase.spark", + Map("hbase.columns.mapping" -> "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b,", + "hbase.table" -> "t1")) + + df.registerTempTable("hbaseTmp") + } + + override def afterAll() { + TEST_UTIL.deleteTable(TableName.valueOf(tableName)) + logInfo("shuting down minicluster") + TEST_UTIL.shutdownMiniCluster() + + sc.stop() + } + + + /** + * A example of query three fields and also only using points for the filter + */ + test("testPointOnlyRowKeyQuery") { + val results = sqlContext.sql("SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTmp " + + "WHERE " + + "(KEY_FIELD = 'get1' or KEY_FIELD = 'get2' or KEY_FIELD = 'get3')").take(10) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + + assert(results.length == 3) + + assert(executionRules.columnFilterCollection.columnFilterMap.size == 1) + val keyFieldFilter = 
executionRules.columnFilterCollection.columnFilterMap.get("KEY_FIELD").get + assert(keyFieldFilter.ranges.length == 0) + assert(keyFieldFilter.points.length == 3) + assert(Bytes.toString(keyFieldFilter.points.head).equals("get1")) + assert(Bytes.toString(keyFieldFilter.points(1)).equals("get2")) + assert(Bytes.toString(keyFieldFilter.points(2)).equals("get3")) + + assert(executionRules.requiredQualifierDefinitionArray.length == 2) + } + + /** + * A example of a OR merge between to ranges the result is one range + * Also an example of less then and greater then + */ + test("testTwoRangeRowKeyQuery") { + val results = sqlContext.sql("SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTmp " + + "WHERE " + + "( KEY_FIELD < 'get2' or KEY_FIELD > 'get3')").take(10) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + + assert(results.length == 3) + + assert(executionRules.columnFilterCollection.columnFilterMap.size == 1) + val keyFieldFilter = executionRules.columnFilterCollection.columnFilterMap.get("KEY_FIELD").get + assert(keyFieldFilter.ranges.length == 2) + assert(keyFieldFilter.points.length == 0) + + assert(executionRules.requiredQualifierDefinitionArray.length == 2) + } + + /** + * A example of a AND merge between to ranges the result is one range + * Also an example of less then and equal to and greater then and equal to + */ + test("testOneCombinedRangeRowKeyQuery") { + val results = sqlContext.sql("SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTmp " + + "WHERE " + + "(KEY_FIELD <= 'get3' and KEY_FIELD >= 'get2')").take(10) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + + assert(results.length == 2) + + assert(executionRules.columnFilterCollection.columnFilterMap.size == 1) + val keyFieldFilter = executionRules.columnFilterCollection.columnFilterMap.get("KEY_FIELD").get + assert(keyFieldFilter.ranges.length == 1) + assert(keyFieldFilter.points.length == 0) + + assert(executionRules.requiredQualifierDefinitionArray.length == 2) + } + + /** + * Do a select with no filters + */ + test("testSelectOnlyQuery") { + + val results = df.select("KEY_FIELD").take(10) + assert(results.length == 5) + + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + assert(executionRules.columnFilterCollection == null) + assert(executionRules.requiredQualifierDefinitionArray.length == 0) + + } + + /** + * A complex query with one point and one range for both the + * rowKey and the a column + */ + test("testSQLPointAndRangeCombo") { + val results = sqlContext.sql("SELECT KEY_FIELD FROM hbaseTmp " + + "WHERE " + + "(KEY_FIELD = 'get1' and B_FIELD < '3') or " + + "(KEY_FIELD >= 'get3' and B_FIELD = '8')").take(5) + + + assert(results.length == 3) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + + assert(executionRules.columnFilterCollection.columnFilterMap.size == 2) + val keyFieldFilter = executionRules.columnFilterCollection.columnFilterMap.get("KEY_FIELD").get + assert(keyFieldFilter.ranges.length == 1) + assert(keyFieldFilter.ranges.head.upperBound == null) + assert(Bytes.toString(keyFieldFilter.ranges.head.lowerBound).equals("get3")) + assert(keyFieldFilter.ranges.head.isLowerBoundEqualTo) + assert(keyFieldFilter.points.length == 1) + assert(Bytes.toString(keyFieldFilter.points.head).equals("get1")) + + val bFieldFilter = executionRules.columnFilterCollection.columnFilterMap.get("B_FIELD").get + assert(bFieldFilter.ranges.length == 1) + assert(bFieldFilter.ranges.head.lowerBound.length == 0) + 
assert(Bytes.toString(bFieldFilter.ranges.head.upperBound).equals("3")) + assert(!bFieldFilter.ranges.head.isUpperBoundEqualTo) + assert(bFieldFilter.points.length == 1) + assert(Bytes.toString(bFieldFilter.points.head).equals("8")) + + assert(executionRules.requiredQualifierDefinitionArray.length == 1) + assert(executionRules.requiredQualifierDefinitionArray.head.columnName.equals("B_FIELD")) + assert(executionRules.requiredQualifierDefinitionArray.head.columnFamily.equals("c")) + assert(executionRules.requiredQualifierDefinitionArray.head.qualifier.equals("b")) + } + + /** + * A complex query with two complex ranges that doesn't merge into one + */ + test("testTwoCompleteRangeNonMergeRowKeyQuery") { + + val results = sqlContext.sql("SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTmp " + + "WHERE " + + "( KEY_FIELD >= 'get1' and KEY_FIELD <= 'get2') or" + + "( KEY_FIELD > 'get3' and KEY_FIELD <= 'get5')").take(10) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + assert(results.length == 4) + + assert(executionRules.columnFilterCollection.columnFilterMap.size == 1) + val keyFieldFilter = executionRules.columnFilterCollection.columnFilterMap.get("KEY_FIELD").get + assert(keyFieldFilter.ranges.length == 2) + assert(keyFieldFilter.points.length == 0) + + assert(executionRules.requiredQualifierDefinitionArray.length == 2) + } + + /** + * A complex query with two complex ranges that does merge into one + */ + test("testTwoCompleteRangeMergeRowKeyQuery") { + val results = sqlContext.sql("SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTmp " + + "WHERE " + + "( KEY_FIELD >= 'get1' and KEY_FIELD <= 'get3') or" + + "( KEY_FIELD > 'get3' and KEY_FIELD <= 'get5')").take(10) + + val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() + + assert(results.length == 5) + + assert(executionRules.columnFilterCollection.columnFilterMap.size == 1) + val keyFieldFilter = executionRules.columnFilterCollection.columnFilterMap.get("KEY_FIELD").get + assert(keyFieldFilter.ranges.length == 1) + assert(keyFieldFilter.points.length == 0) + + assert(executionRules.requiredQualifierDefinitionArray.length == 2) + } +}
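// A minimal end-to-end sketch (not part of this patch), mirroring the suite above:
// it assumes a running cluster with table "t1", column family "c", qualifiers a and b,
// and an existing SparkContext sc. Row-key predicates become pruned scans/gets, and the
// B_FIELD predicate travels to the region servers as a SparkSQLPushDownFilter.
import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.hadoop.hbase.spark.HBaseContext
import org.apache.spark.sql.SQLContext

new HBaseContext(sc, HBaseConfiguration.create())   // cached via LatestHBaseContextCache
val sql = new SQLContext(sc)
sql.load("org.apache.hadoop.hbase.spark",
  Map("hbase.columns.mapping" -> "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b",
    "hbase.table" -> "t1")).registerTempTable("hbaseTmp")

sql.sql("SELECT KEY_FIELD, B_FIELD FROM hbaseTmp " +
  "WHERE KEY_FIELD = 'get1' OR (KEY_FIELD >= 'get3' AND B_FIELD = '8')").show()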