diff --git a/hbase-spark/pom.xml b/hbase-spark/pom.xml
index e48f9e8..7417127 100644
--- a/hbase-spark/pom.xml
+++ b/hbase-spark/pom.xml
@@ -79,6 +79,12 @@
<groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
<artifactId>spark-streaming_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
diff --git a/hbase-spark/src/main/java/org/apache/hadoop/hbase/spark/SparkSQLPushDownFilter.java b/hbase-spark/src/main/java/org/apache/hadoop/hbase/spark/SparkSQLPushDownFilter.java
new file mode 100644
index 0000000..01e73b9
--- /dev/null
+++ b/hbase-spark/src/main/java/org/apache/hadoop/hbase/spark/SparkSQLPushDownFilter.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hbase.spark;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hbase.Cell;
+import org.apache.hadoop.hbase.exceptions.DeserializationException;
+import org.apache.hadoop.hbase.filter.Filter;
+import org.apache.hadoop.hbase.filter.FilterBase;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutput;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+import java.util.HashMap;
+
+/**
+ * This filter will push down all qualifier logic given to us
+ * by SparkSQL so that we can apply the filters at the region server level
+ * and avoid sending the data back to the client to be filtered.
+ */
+public class SparkSQLPushDownFilter extends FilterBase implements Serializable {
+ protected static final Log log = LogFactory.getLog(SparkSQLPushDownFilter.class);
+
+ HashMap<ColumnFamilyQualifierMapKeyWrapper, ColumnFilter> columnFamilyQualifierFilterMap;
+
+ private Filter filter;
+
+ public SparkSQLPushDownFilter(Filter filter) {
+ this.filter = filter;
+ }
+
+ public Filter getFilter() {
+ return filter;
+ }
+
+ public SparkSQLPushDownFilter(
+ HashMap<ColumnFamilyQualifierMapKeyWrapper, ColumnFilter> columnFamilyQualifierFilterMap) {
+ this.columnFamilyQualifierFilterMap = columnFamilyQualifierFilterMap;
+ }
+
+ /**
+ * This method will find the related filter logic for the given
+ * column family and qualifier and then execute it. It does not
+ * clone the incoming cell, to avoid extra object creation.
+ *
+ * @param c The Cell to be validated
+ * @return ReturnCode object to determine if skipping is required
+ * @throws IOException
+ */
+ @Override
+ public ReturnCode filterKeyValue(Cell c) throws IOException {
+
+ //Get filter if one exists
+ ColumnFilter filter =
+ columnFamilyQualifierFilterMap.get(new ColumnFamilyQualifierMapKeyWrapper(
+ c.getFamilyArray(),
+ c.getFamilyOffset(),
+ c.getFamilyLength(),
+ c.getQualifierArray(),
+ c.getQualifierOffset(),
+ c.getQualifierLength()));
+
+ if (filter == null) {
+ //If no filter then just include values
+ return ReturnCode.INCLUDE;
+ } else {
+ //If there is a filter then run validation
+ if (filter.validate(c.getValueArray(), c.getValueOffset(), c.getValueLength())) {
+ return ReturnCode.INCLUDE;
+ } else {
+ //If validation fails then skip the whole row
+ return ReturnCode.NEXT_ROW;
+ }
+ }
+ }
+
+ /**
+ * Used to construct the object from a byte array on the Region Server side
+ *
+ * @param bytes object represented as a byte array
+ * @return a new SparkSQLPushDownFilter object
+ * @throws DeserializationException
+ */
+ public static SparkSQLPushDownFilter parseFrom(final byte [] bytes)
+ throws DeserializationException {
+
+ SparkSQLPushDownFilter result;
+
+ ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
+ ObjectInput in = null;
+ try {
+ in = new ObjectInputStream(bis);
+ result = (SparkSQLPushDownFilter)in.readObject();
+ } catch (Exception e) {
+ throw new DeserializationException(e);
+ } finally {
+ try {
+ bis.close();
+ } catch (IOException ex) {
+ log.error("Error while trying to parseFrom method: ", ex);
+ }
+ try {
+ if (in != null) {
+ in.close();
+ }
+ } catch (IOException ex) {
+ log.error("Error while trying to parseFrom method: ", ex);
+ }
+ }
+ return result;
+ }
+
+ /**
+ * @return The filter serialized as a byte array using Java serialization
+ */
+ public byte [] toByteArray() throws IOException {
+ ByteArrayOutputStream bos = new ByteArrayOutputStream();
+ ObjectOutput out = null;
+ try {
+ out = new ObjectOutputStream(bos);
+ out.writeObject(this);
+ return bos.toByteArray();
+ } finally {
+ try {
+ if (out != null) {
+ out.close();
+ }
+ } catch (IOException ex) {
+ // ignore close exception
+ }
+ try {
+ bos.close();
+ } catch (IOException ex) {
+ // ignore close exception
+ }
+ }
+ }
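+
+ // A minimal round-trip sketch (Scala; "filterMap" is a hypothetical map as built
+ // by ColumnFilterCollection.generateFamilyQualifierFilterMap in DefaultSource):
+ // val original = new SparkSQLPushDownFilter(filterMap)
+ // val serialized = original.toByteArray()
+ // val restored = SparkSQLPushDownFilter.parseFrom(serialized)
+ // The restored instance is what the region server evaluates in filterKeyValue.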
+
+}
diff --git a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/ColumnFamilyQualifierMapKeyWrapper.scala b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/ColumnFamilyQualifierMapKeyWrapper.scala
new file mode 100644
index 0000000..ea6675a
--- /dev/null
+++ b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/ColumnFamilyQualifierMapKeyWrapper.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hbase.spark
+
+import org.apache.hadoop.hbase.util.Bytes
+
+/**
+ * A wrapper class that allows both columnFamily and qualifier to
+ * be the key of a HashMap. It also allows finding the value in a HashMap
+ * without cloning the HBase value from the HBase Cell object.
+ * @param columnFamily ColumnFamily byte array
+ * @param columnFamilyOffSet Offset of columnFamily value in the array
+ * @param columnFamilyLength Length of the columnFamily value in the columnFamily array
+ * @param qualifier Qualifier byte array
+ * @param qualifierOffSet Offset of qualifier value in the array
+ * @param qualifierLength Length of the qualifier value within the array
+ */
+class ColumnFamilyQualifierMapKeyWrapper(val columnFamily:Array[Byte],
+ val columnFamilyOffSet:Int,
+ val columnFamilyLength:Int,
+ val qualifier:Array[Byte],
+ val qualifierOffSet:Int,
+ val qualifierLength:Int)
+ extends Serializable{
+
+ override def equals(other:Any): Boolean = {
+ val otherWrapper = other.asInstanceOf[ColumnFamilyQualifierMapKeyWrapper]
+
+ Bytes.compareTo(columnFamily,
+ columnFamilyOffSet,
+ columnFamilyLength,
+ otherWrapper.columnFamily,
+ otherWrapper.columnFamilyOffSet,
+ otherWrapper.columnFamilyLength) == 0 && Bytes.compareTo(qualifier,
+ qualifierOffSet,
+ qualifierLength,
+ otherWrapper.qualifier,
+ otherWrapper.qualifierOffSet,
+ otherWrapper.qualifierLength) == 0
+ }
+
+ override def hashCode():Int = {
+ Bytes.hashCode(columnFamily, columnFamilyOffSet, columnFamilyLength) +
+ Bytes.hashCode(qualifier, qualifierOffSet, qualifierLength)
+ }
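+
+ // Lookup sketch (Scala; mirrors SparkSQLPushDownFilter.filterKeyValue, where
+ // "cell" and "columnFamilyQualifierFilterMap" are assumed to be in scope):
+ // val key = new ColumnFamilyQualifierMapKeyWrapper(
+ // cell.getFamilyArray, cell.getFamilyOffset, cell.getFamilyLength,
+ // cell.getQualifierArray, cell.getQualifierOffset, cell.getQualifierLength)
+ // val columnFilter = columnFamilyQualifierFilterMap.get(key)
+ // No bytes are copied out of the Cell; equals and hashCode work directly on
+ // the backing arrays using the offsets and lengths.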
+}
diff --git a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
new file mode 100644
index 0000000..8d6fb4d
--- /dev/null
+++ b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
@@ -0,0 +1,926 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hbase.spark
+
+import java.sql.Timestamp
+import java.util
+import java.util.concurrent.ConcurrentLinkedQueue
+
+import org.apache.hadoop.hbase.client.{ConnectionFactory, Get, Result, Scan}
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.hadoop.hbase.{TableName, HBaseConfiguration}
+import org.apache.spark.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types._
+
+import scala.collection.mutable
+
+/**
+ * DefaultSource for integration with Spark's dataframe datasources.
+ * This class will produce a relationProvider based on the input given to it from Spark.
+ *
+ * In all, this DefaultSource supports the following datasource functionality:
+ * - Scan range pruning through filter push down logic based on rowKeys
+ * - Filter push down logic on columns that are not rowKey columns
+ * - Qualifier filtering based on columns used in the SparkSQL statement
+ * - Type conversions of basic SQL types
+ */
+class DefaultSource extends RelationProvider {
+
+ val TABLE_KEY:String = "hbase.table"
+ val SCHEMA_COLUMNS_MAPPING_KEY:String = "hbase.columns.mapping"
+ val BATCHING_NUM_KEY:String = "hbase.batching.num"
+ val CACHING_NUM_KEY:String = "hbase.caching.num"
+ val HBASE_CONFIG_RESOURCES_LOCATIONS:String = "hbase.config.resources"
+ val USE_HBASE_CONTEXT:String = "hbase.use.hbase.context"
+
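+ // Example options for createRelation (mirrors DefaultSourceSuite):
+ // sqlContext.load("org.apache.hadoop.hbase.spark",
+ // Map("hbase.columns.mapping" ->
+ // "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b",
+ // "hbase.table" -> "t1"))
+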
+ /**
+ * Given input from SparkSQL, constructs a BaseRelation
+ * @param sqlContext SparkSQL context
+ * @param parameters Parameters given to us from SparkSQL
+ * @return A BaseRelation Object
+ */
+ override def createRelation(sqlContext: SQLContext,
+ parameters: Map[String, String]):
+ BaseRelation = {
+
+
+ val tableName = parameters.get(TABLE_KEY)
+ if (tableName.isEmpty)
+ throw new Throwable("Invalid value for " + TABLE_KEY + " '" + tableName + "'")
+
+ val schemaMappingString = parameters.getOrElse(SCHEMA_COLUMNS_MAPPING_KEY, "")
+ val batchingNumStr = parameters.getOrElse(BATCHING_NUM_KEY, "1000")
+ val cachingNumStr = parameters.getOrElse(CACHING_NUM_KEY, "1000")
+ val hbaseConfigResources = parameters.getOrElse(HBASE_CONFIG_RESOURCES_LOCATIONS, "")
+ val useHBaseResources = parameters.getOrElse(USE_HBASE_CONTEXT, "true")
+
+ val batchingNum:Int = try {
+ batchingNumStr.toInt
+ } catch {
+ case e:Exception => throw
+ new Throwable("Invalid value for " + BATCHING_NUM_KEY +" '" + batchingNumStr + "'", e )
+ }
+
+ val cachingNum:Int = try {
+ cachingNumStr.toInt
+ } catch {
+ case e:Exception => throw
+ new Throwable("Invalid value for " + CACHING_NUM_KEY +" '" + cachingNumStr + "'", e )
+ }
+
+ new HBaseRelation(tableName.get,
+ generateSchemaMappingMap(schemaMappingString),
+ batchingNum,
+ cachingNum,
+ hbaseConfigResources,
+ useHBaseResources.equalsIgnoreCase("true"))(sqlContext)
+ }
+
+ /**
+ * Reads the SCHEMA_COLUMNS_MAPPING_KEY and converts it to a map of
+ * SchemaQualifierDefinitions with the original sql column name as the key
+ * @param schemaMappingString The schema mapping string from the SparkSQL map
+ * @return A map of definitions keyed by the SparkSQL column name
+ */
+ def generateSchemaMappingMap(schemaMappingString:String):
+ java.util.HashMap[String, SchemaQualifierDefinition] = {
+ try {
+ val columnDefinitions = schemaMappingString.split(',')
+ val resultingMap = new java.util.HashMap[String, SchemaQualifierDefinition]()
+ columnDefinitions.map(cd => {
+ val parts = cd.trim.split(' ')
+
+ //Make sure we get three parts
+ //
+ if (parts.length == 3) {
+ val hbaseDefinitionParts = if (parts(2).charAt(0) == ':') {
+ Array[String]("", "key")
+ } else {
+ parts(2).split(':')
+ }
+ resultingMap.put(parts(0), new SchemaQualifierDefinition(parts(0),
+ parts(1), hbaseDefinitionParts(0), hbaseDefinitionParts(1)))
+ } else {
+ throw new Throwable("Invalid value for schema mapping '" + cd +
+ "' should be ' :' " +
+ "for columns and ' :' for rowKeys")
+ }
+ })
+ resultingMap
+ } catch {
+ case e:Exception => throw
+ new Throwable("Invalid value for " + SCHEMA_COLUMNS_MAPPING_KEY +
+ " '" + schemaMappingString + "'", e )
+ }
+ }
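+
+ // Example: "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b"
+ // maps KEY_FIELD to the rowKey and A_FIELD/B_FIELD to qualifiers 'a' and 'b'
+ // in column family 'c', all read as Spark SQL StringType.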
+}
+
+/**
+ * Implementation of Spark BaseRelation that will build up our scan logic,
+ * do the scan pruning, filter push down, and value conversions
+ *
+ * @param tableName HBase table that we plan to read from
+ * @param schemaMappingDefinition SchemaMapping information to map HBase
+ * Qualifiers to SparkSQL columns
+ * @param batchingNum The batching number to be applied to the
+ * scan object
+ * @param cachingNum The caching number to be applied to the
+ * scan object
+ * @param configResources Optional comma separated list of config resources
+ * to get based on their URI
+ * @param useHBaseContext If true this will look to see if
+ * LatestHBaseContextCache.latest is populated to use that
+ * connection information
+ * @param sqlContext SparkSQL context
+ */
+class HBaseRelation (val tableName:String,
+ val schemaMappingDefinition:java.util.HashMap[String, SchemaQualifierDefinition],
+ val batchingNum:Int,
+ val cachingNum:Int,
+ val configResources:String,
+ val useHBaseContext:Boolean) (
+ @transient val sqlContext:SQLContext)
+ extends BaseRelation with PrunedFilteredScan with Logging {
+
+ //create or get latest HBaseContext
+ @transient val hbaseContext:HBaseContext = if (useHBaseContext) {
+ LatestHBaseContextCache.latest
+ } else {
+ val config = HBaseConfiguration.create()
+ configResources.split(",").foreach( r => config.addResource(r))
+ new HBaseContext(sqlContext.sparkContext, config)
+ }
+
+ /**
+ * Generates a Spark SQL schema object so Spark SQL knows what is being
+ * provided by this BaseRelation
+ *
+ * @return schema generated from the SCHEMA_COLUMNS_MAPPING_KEY value
+ */
+ override def schema: StructType = {
+
+ val metadataBuilder = new MetadataBuilder()
+
+ val structFieldArray = new Array[StructField](schemaMappingDefinition.size())
+
+ val schemaMappingDefinitionIt = schemaMappingDefinition.values().iterator()
+ var indexCounter = 0
+ while (schemaMappingDefinitionIt.hasNext) {
+ val c = schemaMappingDefinitionIt.next()
+
+ val metadata = metadataBuilder.putString("name", c.columnName).build()
+ val structField =
+ new StructField(c.columnName, c.columnSparkSqlType, nullable = true, metadata)
+
+ structFieldArray(indexCounter) = structField
+ indexCounter += 1
+ }
+
+ val result = new StructType(structFieldArray)
+ //TODO push schema to listener
+ result
+ }
+
+ /**
+ * Here we are building the functionality to populate the resulting RDD[Row].
+ * This is where we will do the following:
+ * - Filter push down
+ * - Scan or GetList pruning
+ * - Executing our scan(s) or/and GetList to generate result
+ *
+ * @param requiredColumns The columns that are being requested by the requesting query
+ * @param filters The filters that are being applied by the requesting query
+ * @return RDD with all the results from HBase needed for SparkSQL to
+ * execute the query on
+ */
+ override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
+
+ val columnFilterCollection = buildColumnFilterCollection(filters)
+
+ val requiredQualifierDefinitionArray = new mutable.MutableList[SchemaQualifierDefinition]
+ requiredColumns.foreach( c => {
+ val definition = schemaMappingDefinition.get(c)
+ if (definition.columnFamilyBytes.length > 0) {
+ requiredQualifierDefinitionArray += definition
+ }
+ })
+
+ //Create a local variable so that scala doesn't have to
+ // serialize the whole HBaseRelation Object
+ val serializableDefinitionMap = schemaMappingDefinition
+
+
+ //retain the information for unit testing checks
+ DefaultSourceStaticUtils.populateLatestExecutionRules(columnFilterCollection,
+ requiredQualifierDefinitionArray)
+
+ var resultRDD: RDD[Row] = null
+
+ if (columnFilterCollection != null) {
+ val pushDownFilterJava =
+ new SparkSQLPushDownFilter(
+ columnFilterCollection.generateFamilyQualifierFilterMap(schemaMappingDefinition))
+
+ val getList = new util.ArrayList[Get]()
+ val rddList = new util.ArrayList[RDD[Row]]()
+
+ val it = columnFilterCollection.columnFilterMap.iterator
+ while (it.hasNext) {
+ val e = it.next()
+ val columnDefinition = schemaMappingDefinition.get(e._1)
+ //check if this is a rowKey
+ if (columnDefinition != null && columnDefinition.columnFamily.isEmpty) {
+ //add points to getList
+ e._2.points.foreach(p => {
+ val get = new Get(p)
+ requiredQualifierDefinitionArray.foreach( d => get.addColumn(d.columnFamilyBytes, d.qualifierBytes))
+ getList.add(get)
+ })
+
+ val rangeIt = e._2.ranges.iterator
+
+ while (rangeIt.hasNext) {
+ val r = rangeIt.next()
+
+ val scan = new Scan()
+ scan.setBatch(batchingNum)
+ scan.setCaching(cachingNum)
+ requiredQualifierDefinitionArray.foreach( d => scan.addColumn(d.columnFamilyBytes, d.qualifierBytes))
+
+ if (pushDownFilterJava.columnFamilyQualifierFilterMap.size() > 0) {
+ scan.setFilter(pushDownFilterJava)
+ }
+
+ //Check if there is a lower bound
+ if (r.lowerBound != null && r.lowerBound.length > 0) {
+
+ if (r.isLowerBoundEqualTo) {
+ //HBase startRow is inclusive: Therefore it acts like isLowerBoundEqualTo
+ // by default
+ scan.setStartRow(r.lowerBound)
+ } else {
+ //Since this is not equalTo, we want the next value, so we need
+ // to append another byte to the start key. That new byte will be
+ // the min byte value.
+ val newArray = new Array[Byte](r.lowerBound.length + 1)
+ System.arraycopy(r.lowerBound, 0, newArray, 0, r.lowerBound.length)
+
+ //new Min Byte
+ newArray(r.lowerBound.length) = Byte.MinValue
+ scan.setStartRow(newArray)
+ }
+ }
+
+ //Check if there is an upperBound
+ if (r.upperBound != null && r.upperBound.length > 0) {
+ if (r.isUpperBoundEqualTo) {
+ //HBase stopRow is exclusive: therefore it DOESN'T act like isUpperBoundEqualTo
+ // by default. So we need to add a new max byte to the stopRow key
+ val newArray = new Array[Byte](r.upperBound.length + 1)
+ System.arraycopy(r.upperBound, 0, newArray, 0, r.upperBound.length)
+
+ //New Max Bytes
+ newArray(r.upperBound.length) = Byte.MaxValue
+
+ scan.setStopRow(newArray)
+ } else {
+ //Here equalTo is false for Upper bound which is exclusive and
+ // HBase stopRow acts like that by default so no need to mutate the
+ // rowKey
+ scan.setStopRow(r.upperBound)
+ }
+ }
+
+ val rdd = hbaseContext.hbaseRDD(TableName.valueOf(tableName), scan).map(r => {
+ Row.fromSeq(requiredColumns.map(c =>
+ DefaultSourceStaticUtils.getValue(c, serializableDefinitionMap, r._2)))
+ })
+ rddList.add(rdd)
+ }
+ }
+ }
+
+ //If there is more than one RDD then we have to union them together
+ for (i <- 0 until rddList.size()) {
+ if (resultRDD == null) resultRDD = rddList.get(i)
+ else resultRDD = resultRDD.union(rddList.get(i))
+
+ }
+
+ //If there are gets then we can get them from the driver and union that rdd in
+ // with the rest of the values.
+ if (getList.size() > 0) {
+ val connection = ConnectionFactory.createConnection(hbaseContext.tmpHdfsConfiguration)
+ val table = connection.getTable(TableName.valueOf(tableName))
+ try {
+ val results = table.get(getList)
+ val rowList = mutable.MutableList[Row]()
+ for (i <- 0 until results.length) {
+ val rowArray = requiredColumns.map(c =>
+ DefaultSourceStaticUtils.getValue(c, schemaMappingDefinition, results(i)))
+ rowList += Row.fromSeq(rowArray)
+ }
+ val getRDD = sqlContext.sparkContext.parallelize(rowList)
+ if (resultRDD == null) resultRDD = getRDD
+ else {
+ resultRDD = resultRDD.union(getRDD)
+ }
+ } finally {
+ try {
+ table.close()
+ } finally {
+ connection.close()
+ }
+ }
+
+ }
+ }
+ if (resultRDD == null) {
+ val scan = new Scan()
+ scan.setBatch(batchingNum)
+ scan.setCaching(cachingNum)
+ requiredQualifierDefinitionArray.foreach( d => scan.addColumn(d.columnFamilyBytes, d.qualifierBytes))
+
+ val rdd = hbaseContext.hbaseRDD(TableName.valueOf(tableName), scan).map(r => {
+ Row.fromSeq(requiredColumns.map(c => DefaultSourceStaticUtils.getValue(c,
+ serializableDefinitionMap, r._2)))
+ })
+ resultRDD=rdd
+ }
+ resultRDD
+ }
+
+ /**
+ * Root recursive function that will loop over the filters provided by
+ * SparkSQL. Some filters are AND or OR functions that contain additional
+ * filters, hence the need for recursion.
+ *
+ * @param filters Filters provided by SparkSQL.
+ * Filters are joined with the AND operator
+ * @return A ColumnFilterCollection which is a consolidated construct to
+ * hold the high level filter information
+ */
+ def buildColumnFilterCollection(filters: Array[Filter]): ColumnFilterCollection = {
+ var superCollection: ColumnFilterCollection = null
+
+ filters.foreach( f => {
+ val parentCollection = new ColumnFilterCollection
+ buildColumnFilterCollection(parentCollection, f)
+ if (superCollection == null) superCollection = parentCollection
+ else {superCollection.mergeIntersect(parentCollection)}
+ })
+ superCollection
+ }
+
+ /**
+ * Recursive function that will work to convert Spark Filter objects to ColumnFilterCollection
+ *
+ * @param parentFilterCollection Parent ColumnFilterCollection
+ * @param filter Current given filter from SparkSQL
+ */
+ def buildColumnFilterCollection(parentFilterCollection:ColumnFilterCollection,
+ filter:Filter): Unit = {
+ filter match {
+
+ case EqualTo(attr, value) =>
+ parentFilterCollection.mergeUnion(attr,
+ new ColumnFilter(DefaultSourceStaticUtils.getByteValue(attr,
+ schemaMappingDefinition, value.toString)))
+
+ case LessThan(attr, value) =>
+ parentFilterCollection.mergeUnion(attr, new ColumnFilter(null,
+ new ScanRange(DefaultSourceStaticUtils.getByteValue(attr,
+ schemaMappingDefinition, value.toString), false,
+ new Array[Byte](0), true)))
+
+ case GreaterThan(attr, value) =>
+ parentFilterCollection.mergeUnion(attr, new ColumnFilter(null,
+ new ScanRange(null, true, DefaultSourceStaticUtils.getByteValue(attr,
+ schemaMappingDefinition, value.toString), false)))
+
+ case LessThanOrEqual(attr, value) =>
+ parentFilterCollection.mergeUnion(attr, new ColumnFilter(null,
+ new ScanRange(DefaultSourceStaticUtils.getByteValue(attr,
+ schemaMappingDefinition, value.toString), true,
+ new Array[Byte](0), true)))
+
+ case GreaterThanOrEqual(attr, value) =>
+ parentFilterCollection.mergeUnion(attr, new ColumnFilter(null,
+ new ScanRange(null, true, DefaultSourceStaticUtils.getByteValue(attr,
+ schemaMappingDefinition, value.toString), true)))
+
+ case Or(left, right) =>
+ buildColumnFilterCollection(parentFilterCollection, left)
+ val rightSideCollection = new ColumnFilterCollection
+ buildColumnFilterCollection(rightSideCollection, right)
+ parentFilterCollection.mergeUnion(rightSideCollection)
+ case And(left, right) =>
+ buildColumnFilterCollection(parentFilterCollection, left)
+ val rightSideCollection = new ColumnFilterCollection
+ buildColumnFilterCollection(rightSideCollection, right)
+ parentFilterCollection.mergeIntersect(rightSideCollection)
+ case _ => //nothing
+ }
+ }
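+
+ // For example, a predicate like "KEY_FIELD < 'get2' OR KEY_FIELD > 'get3'"
+ // becomes a single ColumnFilter for KEY_FIELD holding two ScanRanges
+ // (see DefaultSourceSuite.testTwoRangeRowKeyQuery).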
+}
+
+/**
+ * Construct that contains column data spanning both SparkSQL and HBase
+ *
+ * @param columnName SparkSQL column name
+ * @param colType SparkSQL column type
+ * @param columnFamily HBase column family
+ * @param qualifier HBase qualifier name
+ */
+case class SchemaQualifierDefinition(columnName:String,
+ colType:String,
+ columnFamily:String,
+ qualifier:String) extends Serializable {
+ val columnFamilyBytes = Bytes.toBytes(columnFamily)
+ val qualifierBytes = Bytes.toBytes(qualifier)
+ val columnSparkSqlType = if (colType.equals("BOOLEAN")) BooleanType
+ else if (colType.equals("TINYINT")) IntegerType
+ else if (colType.equals("INT")) IntegerType
+ else if (colType.equals("BIGINT")) LongType
+ else if (colType.equals("FLOAT")) FloatType
+ else if (colType.equals("DOUBLE")) DoubleType
+ else if (colType.equals("STRING")) StringType
+ else if (colType.equals("TIMESTAMP")) TimestampType
+ else if (colType.equals("DECIMAL")) StringType //DataTypes.createDecimalType(precision, scale)
+ else throw new Throwable("Unsupported column type :" + colType)
+}
+
+/**
+ * Construct to contain a single scan range's information. Also
+ * provides functions to merge with other scan ranges through AND
+ * or OR operators
+ *
+ * @param upperBound Upper bound of scan
+ * @param isUpperBoundEqualTo Include upper bound value in the results
+ * @param lowerBound Lower bound of scan
+ * @param isLowerBoundEqualTo Include lower bound value in the results
+ */
+class ScanRange(var upperBound:Array[Byte], var isUpperBoundEqualTo:Boolean,
+ var lowerBound:Array[Byte], var isLowerBoundEqualTo:Boolean)
+ extends Serializable {
+
+ /**
+ * Function to merge another scan object through an AND operation
+ * @param other Other scan object
+ */
+ def mergeIntersect(other:ScanRange): Unit = {
+ val upperBoundCompare = compareRange(upperBound, other.upperBound)
+ val lowerBoundCompare = compareRange(lowerBound, other.lowerBound)
+
+ upperBound = if (upperBoundCompare <0) upperBound else other.upperBound
+ lowerBound = if (lowerBoundCompare >0) lowerBound else other.lowerBound
+
+ isLowerBoundEqualTo = if (lowerBoundCompare == 0)
+ isLowerBoundEqualTo && other.isLowerBoundEqualTo
+ else isLowerBoundEqualTo
+
+ isUpperBoundEqualTo = if (upperBoundCompare == 0)
+ isUpperBoundEqualTo && other.isUpperBoundEqualTo
+ else isUpperBoundEqualTo
+ }
+
+ /**
+ * Function to merge another scan object through an OR operation
+ * @param other Other scan object
+ */
+ def mergeUnion(other:ScanRange): Unit = {
+
+ val upperBoundCompare = compareRange(upperBound, other.upperBound)
+ val lowerBoundCompare = compareRange(lowerBound, other.lowerBound)
+
+ upperBound = if (upperBoundCompare >0) upperBound else other.upperBound
+ lowerBound = if (lowerBoundCompare <0) lowerBound else other.lowerBound
+
+ isLowerBoundEqualTo = if (lowerBoundCompare == 0)
+ isLowerBoundEqualTo || other.isLowerBoundEqualTo
+ else isLowerBoundEqualTo
+
+ isUpperBoundEqualTo = if (upperBoundCompare == 0)
+ isUpperBoundEqualTo || other.isUpperBoundEqualTo
+ else isUpperBoundEqualTo
+ }
+
+ /**
+ * Common function to see if this scan overlaps with another
+ *
+ * Reference Visual
+ *
+ * A B
+ * |---------------------------|
+ * LL--------------LU
+ * RL--------------RU
+ *
+ * A = lowest value is byte[0]
+ * B = highest value is null
+ * LL = Left Lower Bound
+ * LU = Left Upper Bound
+ * RL = Right Lower Bound
+ * RU = Right Upper Bound
+ *
+ * @param other Other scan object
+ * @return True if the ranges overlap, false if they do not
+ */
+ def doesOverLap(other:ScanRange): Boolean = {
+
+ var leftRange:ScanRange = null
+ var rightRange:ScanRange = null
+
+ //First identify the Left range
+ // Also lower bound can't be null
+ if (Bytes.compareTo(lowerBound, other.lowerBound) <=0) {
+ leftRange = this
+ rightRange = other
+ } else {
+ leftRange = other
+ rightRange = this
+ }
+
+ //Then see if leftRange.upperBound is null or if leftRange.upperBound
+ // is greater than or equal to rightRange.lowerBound
+ leftRange.upperBound == null ||
+ Bytes.compareTo(leftRange.upperBound, rightRange.lowerBound) >= 0
+ }
+
+ /**
+ * Special compare logic because we can have null values
+ * for left or right bound
+ *
+ * @param left Left byte array
+ * @param right Right byte array
+ * @return 0 if equal, 1 if left is greater, and -1 if right is greater
+ */
+ def compareRange(left:Array[Byte], right:Array[Byte]): Int = {
+ if (left == null && right == null) 0
+ else if (left == null && right != null) 1
+ else if (left != null && right == null) -1
+ else Bytes.compareTo(left, right)
+ }
+
+ override def toString:String = {
+ "ScanRange:(" + Bytes.toString(upperBound) + "," + isUpperBoundEqualTo + "," +
+ Bytes.toString(lowerBound) + "," + isLowerBoundEqualTo + ")"
+ }
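+
+ // Merge sketch (string bounds shown for readability; real bounds are byte arrays):
+ // union of [get1, get3] and [get2, get5] -> [get1, get5]
+ // intersect of [get1, get3] and [get2, get5] -> [get2, get3]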
+}
+
+/**
+ * Contains information related to the filters for a given column.
+ * This can contain many ranges or points.
+ *
+ * @param currentPoint the initial point when the filter is created
+ * @param currentRange the initial scanRange when the filter is created
+ */
+class ColumnFilter (currentPoint:Array[Byte] = null,
+ currentRange:ScanRange = null) extends Serializable {
+ //Collection of ranges
+ var ranges = new mutable.MutableList[ScanRange]()
+ if (currentRange != null ) ranges.+=(currentRange)
+
+ //Collection of points
+ var points = new mutable.MutableList[Array[Byte]]()
+ if (currentPoint != null) {
+ points.+=(currentPoint)
+ }
+
+ /**
+ * This will validate a given value against the filter's points and/or ranges;
+ * the result indicates whether the value passes the filter
+ *
+ * @param value Value to be validated
+ * @param valueOffSet The offset of the value
+ * @param valueLength The length of the value
+ * @return True if the value passes the filter, false if not
+ */
+ def validate(value:Array[Byte], valueOffSet:Int, valueLength:Int):Boolean = {
+ var result = false
+
+ points.foreach( p => {
+ if (Bytes.equals(p, 0, p.length, value, valueOffSet, valueLength)) {
+ result = true
+ }
+ })
+
+ ranges.foreach( r => {
+ val upperBoundPass = r.upperBound == null ||
+ (r.isUpperBoundEqualTo &&
+ Bytes.compareTo(r.upperBound, 0, r.upperBound.length,
+ value, valueOffSet, valueLength) >= 0) ||
+ (!r.isUpperBoundEqualTo &&
+ Bytes.compareTo(r.upperBound, 0, r.upperBound.length,
+ value, valueOffSet, valueLength) > 0)
+
+ val lowerBoundPass = r.lowerBound == null || r.lowerBound.length == 0 ||
+ (r.isLowerBoundEqualTo &&
+ Bytes.compareTo(r.lowerBound, 0, r.lowerBound.length,
+ value, valueOffSet, valueLength) <= 0) ||
+ (!r.isLowerBoundEqualTo &&
+ Bytes.compareTo(r.lowerBound, 0, r.lowerBound.length,
+ value, valueOffSet, valueLength) < 0)
+
+ result = result || (upperBoundPass && lowerBoundPass)
+ })
+ result
+ }
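+
+ // For example, a ColumnFilter built from "B_FIELD < '3' OR B_FIELD = '8'"
+ // passes the values "1" and "8" but rejects "4" (comparisons are on the
+ // raw bytes, so digits compare lexicographically here).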
+
+ /**
+ * This will allow us to merge filter logic that is joined to the existing filter
+ * through an OR operator
+ *
+ * @param other Filter to merge
+ */
+ def mergeUnion(other:ColumnFilter): Unit = {
+ other.points.foreach( p => points += p)
+
+ other.ranges.foreach( otherR => {
+ var doesOverLap = false
+ ranges.foreach{ r =>
+ if (r.doesOverLap(otherR)) {
+ r.mergeUnion(otherR)
+ doesOverLap = true
+ }}
+ if (!doesOverLap) ranges.+=(otherR)
+ })
+ }
+
+ /**
+ * This will allow us to merge filter logic that is joined to the existing filter
+ * through an AND operator
+ *
+ * @param other Filter to merge
+ */
+ def mergeIntersect(other:ColumnFilter): Unit = {
+ val survivingPoints = new mutable.MutableList[Array[Byte]]()
+ points.foreach( p => {
+ other.points.foreach( otherP => {
+ if (Bytes.equals(p, otherP)) {
+ survivingPoints.+=(p)
+ }
+ })
+ })
+ points = survivingPoints
+
+ val survivingRanges = new mutable.MutableList[ScanRange]()
+
+ other.ranges.foreach( otherR => {
+ ranges.foreach( r => {
+ if (r.doesOverLap(otherR)) {
+ r.mergeIntersect(otherR)
+ survivingRanges += r
+ }
+ })
+ })
+ ranges = survivingRanges
+ }
+
+ override def toString:String = {
+ val strBuilder = new StringBuilder
+ strBuilder.append("(points:(")
+ var isFirst = true
+ points.foreach( p => {
+ if (isFirst) isFirst = false
+ else strBuilder.append(",")
+ strBuilder.append(Bytes.toString(p))
+ })
+ strBuilder.append("),ranges:")
+ isFirst = true
+ ranges.foreach( r => {
+ if (isFirst) isFirst = false
+ else strBuilder.append(",")
+ strBuilder.append(r)
+ })
+ strBuilder.append("))")
+ strBuilder.toString()
+ }
+}
+
+/**
+ * A collection of ColumnFilters indexed by column names.
+ *
+ * Also contains merge operations that will consolidate the filters
+ * per column name
+ */
+class ColumnFilterCollection {
+ val columnFilterMap = new mutable.HashMap[String, ColumnFilter]
+
+ def clear(): Unit = {
+ columnFilterMap.clear()
+ }
+
+ /**
+ * This will allow us to merge filter logic that is joined to the existing filter
+ * through an OR operator. This will merge a single column's filter
+ *
+ * @param column The column to be merged
+ * @param other The other ColumnFilter object to merge
+ */
+ def mergeUnion(column:String, other:ColumnFilter): Unit = {
+ val existingFilter = columnFilterMap.get(column)
+ if (existingFilter.isEmpty) {
+ columnFilterMap.+=((column, other))
+ } else {
+ existingFilter.get.mergeUnion(other)
+ }
+ }
+
+ /**
+ * This will allow us to merge all filters in the existing collection
+ * to the filters in the other collection. All merges are done as a result
+ * of an OR operator
+ *
+ * @param other The other Column Filter Collection to be merged
+ */
+ def mergeUnion(other:ColumnFilterCollection): Unit = {
+ other.columnFilterMap.foreach( e => {
+ mergeUnion(e._1, e._2)
+ })
+ }
+
+ /**
+ * This will allow us to merge all filters in the existing collection
+ * to the filters in the other collection. All merges are done as a result
+ * of an AND operator
+ *
+ * @param other The other ColumnFilterCollection to be merged
+ */
+ def mergeIntersect(other:ColumnFilterCollection): Unit = {
+ other.columnFilterMap.foreach( e => {
+ val existingColumnFilter = columnFilterMap.get(e._1)
+ if (existingColumnFilter.isEmpty) {
+ columnFilterMap += e
+ } else {
+ existingColumnFilter.get.mergeIntersect(e._2)
+ }
+ })
+ }
+
+ /**
+ * This will collect all the filter information in a way that is optimized
+ * for the HBase filter command, allowing the filter to be accessed
+ * with columnFamily and qualifier information
+ *
+ * @param schemaDefinitionMap Schema Map that will help us map the right filters
+ * to the correct columns
+ * @return HashMap of column filters
+ */
+ def generateFamilyQualifierFilterMap(schemaDefinitionMap:
+ java.util.HashMap[String, SchemaQualifierDefinition]):
+ util.HashMap[ColumnFamilyQualifierMapKeyWrapper, ColumnFilter] = {
+ val familyQualifierFilterMap =
+ new util.HashMap[ColumnFamilyQualifierMapKeyWrapper, ColumnFilter]()
+
+ columnFilterMap.foreach( e => {
+ val definition = schemaDefinitionMap.get(e._1)
+ //Don't add rowKeyFilter
+ if (definition.columnFamilyBytes.size > 0) {
+ familyQualifierFilterMap.put(
+ new ColumnFamilyQualifierMapKeyWrapper(
+ definition.columnFamilyBytes, 0, definition.columnFamilyBytes.length,
+ definition.qualifierBytes, 0, definition.qualifierBytes.length), e._2)
+ }
+ })
+ familyQualifierFilterMap
+ }
+
+ override def toString:String = {
+ val strBuilder = new StringBuilder
+ columnFilterMap.foreach( e => strBuilder.append(e))
+ strBuilder.toString()
+ }
+}
+
+/**
+ * Object that stores static helper functions and also holds the last executed
+ * query information, which can be used for unit testing.
+ */
+object DefaultSourceStaticUtils {
+
+ //This will contain the last 5 filters and required fields used in buildScan
+ // These values can be used in unit testing to make sure we are converting
+ // The Spark SQL input correctly
+ val lastFiveExecutionRules =
+ new ConcurrentLinkedQueue[ExecutionRuleForUnitTesting]()
+
+ /**
+ * This method is to populate the lastFiveExecutionRules for unit test purposes.
+ * This method is not thread safe.
+ *
+ * @param columnFilterCollection The filters in the last job
+ * @param requiredQualifierDefinitionArray The required columns in the last job
+ */
+ def populateLatestExecutionRules(columnFilterCollection: ColumnFilterCollection,
+ requiredQualifierDefinitionArray: mutable.MutableList[SchemaQualifierDefinition]):Unit = {
+ lastFiveExecutionRules.add(new ExecutionRuleForUnitTesting(
+ columnFilterCollection, requiredQualifierDefinitionArray))
+ while (lastFiveExecutionRules.size() > 5) {
+ lastFiveExecutionRules.poll()
+ }
+ }
+
+ /**
+ * This method will convert the result content from HBase into the
+ * SQL value type that is requested by the Spark SQL schema definition
+ *
+ * @param columnName The name of the SparkSQL Column
+ * @param schemaMappingDefinition The schema definition map
+ * @param r The result object from HBase
+ * @return The converted object type
+ */
+ def getValue(columnName: String,
+ schemaMappingDefinition: java.util.HashMap[String, SchemaQualifierDefinition],
+ r: Result): Any = {
+
+ val columnDef = schemaMappingDefinition.get(columnName)
+
+ if (columnDef == null) throw new Throwable("Unknown column:" + columnName)
+
+
+ if (columnDef.columnFamilyBytes.isEmpty) {
+ val rowKey = r.getRow
+
+ columnDef.columnSparkSqlType match {
+ case IntegerType => Bytes.toInt(rowKey)
+ case LongType => Bytes.toLong(rowKey)
+ case FloatType => Bytes.toFloat(rowKey)
+ case DoubleType => Bytes.toDouble(rowKey)
+ case StringType => Bytes.toString(rowKey)
+ case TimestampType => new Timestamp(Bytes.toLong(rowKey))
+ case _ => Bytes.toString(rowKey)
+ }
+ } else {
+ val cellByteValue =
+ r.getColumnLatestCell(columnDef.columnFamilyBytes, columnDef.qualifierBytes)
+ if (cellByteValue == null) null
+ else columnDef.columnSparkSqlType match {
+ case IntegerType => Bytes.toInt(cellByteValue.getValueArray,
+ cellByteValue.getValueOffset, cellByteValue.getValueLength)
+ case LongType => Bytes.toLong(cellByteValue.getValueArray,
+ cellByteValue.getValueOffset, cellByteValue.getValueLength)
+ case FloatType => Bytes.toFloat(cellByteValue.getValueArray,
+ cellByteValue.getValueOffset)
+ case DoubleType => Bytes.toDouble(cellByteValue.getValueArray,
+ cellByteValue.getValueOffset)
+ case StringType => Bytes.toString(cellByteValue.getValueArray,
+ cellByteValue.getValueOffset, cellByteValue.getValueLength)
+ case TimestampType => new Timestamp(Bytes.toLong(cellByteValue.getValueArray,
+ cellByteValue.getValueOffset, cellByteValue.getValueLength))
+ case _ => Bytes.toString(cellByteValue.getValueArray,
+ cellByteValue.getValueOffset, cellByteValue.getValueLength)
+ }
+ }
+ }
+
+ /**
+ * This will convert the value from SparkSQL to the byte array
+ * representation used by HBase
+ *
+ * @param columnName SparkSQL column name
+ * @param schemaMappingDefinition Schema definition map
+ * @param value String value from SparkSQL
+ * @return Returns the byte array to go into HBase
+ */
+ def getByteValue(columnName: String,
+ schemaMappingDefinition: java.util.HashMap[String, SchemaQualifierDefinition],
+ value: String): Array[Byte] = {
+
+ val columnDef = schemaMappingDefinition.get(columnName)
+
+ if (columnDef == null) throw new Throwable("Unknown column:" + columnName)
+ else {
+ columnDef.columnSparkSqlType match {
+ case IntegerType => Bytes.toBytes(value.toInt)
+ case LongType => Bytes.toBytes(value.toLong)
+ case FloatType => Bytes.toBytes(value.toFloat)
+ case DoubleType => Bytes.toBytes(value.toDouble)
+ case StringType => Bytes.toBytes(value)
+ case TimestampType => Bytes.toBytes(value.toLong)
+ case _ => Bytes.toBytes(value)
+ }
+ }
+ }
+
+ class ExecutionRuleForUnitTesting(val columnFilterCollection: ColumnFilterCollection,
+ val requiredQualifierDefinitionArray:
+ mutable.MutableList[SchemaQualifierDefinition])
+}
\ No newline at end of file
diff --git a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
index f060fea..ab4dbf5 100644
--- a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
+++ b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
@@ -64,6 +64,8 @@ class HBaseContext(@transient sc: SparkContext,
val broadcastedConf = sc.broadcast(new SerializableWritable(config))
val credentialsConf = sc.broadcast(new SerializableWritable(job.getCredentials))
+ LatestHBaseContextCache.latest = this
+
if (tmpHdfsConfgFile != null && config != null) {
val fs = FileSystem.newInstance(config)
val tmpPath = new Path(tmpHdfsConfgFile)
@@ -568,3 +570,7 @@ class HBaseContext(@transient sc: SparkContext,
private[spark]
def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]]
}
+
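+/**
+ * Holds the most recently constructed HBaseContext so that HBaseRelation
+ * can reuse its connection information when hbase.use.hbase.context is true.
+ *
+ * Usage sketch (mirrors DefaultSourceSuite.beforeAll):
+ * new HBaseContext(sc, config) // registers itself here as a side effect
+ * sqlContext.load("org.apache.hadoop.hbase.spark",
+ * Map("hbase.table" -> "t1",
+ * "hbase.columns.mapping" -> "KEY_FIELD STRING :key, A_FIELD STRING c:a"))
+ */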
+object LatestHBaseContextCache {
+ var latest:HBaseContext = null
+}
diff --git a/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala b/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala
new file mode 100644
index 0000000..67f7ff0
--- /dev/null
+++ b/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hbase.spark
+
+import org.apache.hadoop.hbase.client.{Put, ConnectionFactory}
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.hadoop.hbase.{TableName, HBaseTestingUtility}
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.{SparkContext, Logging}
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
+
+class DefaultSourceSuite extends FunSuite with
+BeforeAndAfterEach with BeforeAndAfterAll with Logging {
+ @transient var sc: SparkContext = null
+ var TEST_UTIL: HBaseTestingUtility = new HBaseTestingUtility
+
+ val tableName = "t1"
+ val columnFamily = "c"
+
+ var sqlContext:SQLContext = null
+ var df:DataFrame = null
+
+ override def beforeAll() {
+
+ TEST_UTIL.startMiniCluster
+
+ logInfo(" - minicluster started")
+ try
+ TEST_UTIL.deleteTable(TableName.valueOf(tableName))
+ catch {
+ case e: Exception => logInfo(" - no table " + tableName + " found")
+
+ }
+ logInfo(" - creating table " + tableName)
+ TEST_UTIL.createTable(TableName.valueOf(tableName), Bytes.toBytes(columnFamily))
+ logInfo(" - created table")
+
+ sc = new SparkContext("local", "test")
+
+ val connection = ConnectionFactory.createConnection(TEST_UTIL.getConfiguration)
+ val table = connection.getTable(TableName.valueOf("t1"))
+
+ try {
+ var put = new Put(Bytes.toBytes("get1"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("1"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("get2"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("4"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("get3"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("8"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("get4"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo4"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("10"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("get5"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo5"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("8"))
+ table.put(put)
+ } finally {
+ table.close()
+ connection.close()
+ }
+
+ new HBaseContext(sc, TEST_UTIL.getConfiguration)
+ sqlContext = new SQLContext(sc)
+
+ df = sqlContext.load("org.apache.hadoop.hbase.spark",
+ Map("hbase.columns.mapping" -> "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b,",
+ "hbase.table" -> "t1"))
+
+ df.registerTempTable("hbaseTmp")
+ }
+
+ override def afterAll() {
+ TEST_UTIL.deleteTable(TableName.valueOf(tableName))
+ logInfo("shuting down minicluster")
+ TEST_UTIL.shutdownMiniCluster()
+
+ sc.stop()
+ }
+
+
+ /**
+ * An example of querying three fields and also only using points for the filter
+ */
+ test("testPointOnlyRowKeyQuery") {
+ val results = sqlContext.sql("SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTmp " +
+ "WHERE " +
+ "(KEY_FIELD = 'get1' or KEY_FIELD = 'get2' or KEY_FIELD = 'get3')").take(10)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+
+ assert(results.length == 3)
+
+ assert(executionRules.columnFilterCollection.columnFilterMap.size == 1)
+ val keyFieldFilter = executionRules.columnFilterCollection.columnFilterMap.get("KEY_FIELD").get
+ assert(keyFieldFilter.ranges.length == 0)
+ assert(keyFieldFilter.points.length == 3)
+ assert(Bytes.toString(keyFieldFilter.points.head).equals("get1"))
+ assert(Bytes.toString(keyFieldFilter.points(1)).equals("get2"))
+ assert(Bytes.toString(keyFieldFilter.points(2)).equals("get3"))
+
+ assert(executionRules.requiredQualifierDefinitionArray.length == 2)
+ }
+
+ /**
+ * An example of an OR merge between two ranges that do not overlap,
+ * so both ranges are kept. Also an example of less than and greater than.
+ */
+ test("testTwoRangeRowKeyQuery") {
+ val results = sqlContext.sql("SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTmp " +
+ "WHERE " +
+ "( KEY_FIELD < 'get2' or KEY_FIELD > 'get3')").take(10)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+
+ assert(results.length == 3)
+
+ assert(executionRules.columnFilterCollection.columnFilterMap.size == 1)
+ val keyFieldFilter = executionRules.columnFilterCollection.columnFilterMap.get("KEY_FIELD").get
+ assert(keyFieldFilter.ranges.length == 2)
+ assert(keyFieldFilter.points.length == 0)
+
+ assert(executionRules.requiredQualifierDefinitionArray.length == 2)
+ }
+
+ /**
+ * An example of an AND merge between two ranges; the result is one range.
+ * Also an example of less than or equal to, and greater than or equal to.
+ */
+ test("testOneCombinedRangeRowKeyQuery") {
+ val results = sqlContext.sql("SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTmp " +
+ "WHERE " +
+ "(KEY_FIELD <= 'get3' and KEY_FIELD >= 'get2')").take(10)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+
+ assert(results.length == 2)
+
+ assert(executionRules.columnFilterCollection.columnFilterMap.size == 1)
+ val keyFieldFilter = executionRules.columnFilterCollection.columnFilterMap.get("KEY_FIELD").get
+ assert(keyFieldFilter.ranges.length == 1)
+ assert(keyFieldFilter.points.length == 0)
+
+ assert(executionRules.requiredQualifierDefinitionArray.length == 2)
+ }
+
+ /**
+ * Do a select with no filters
+ */
+ test("testSelectOnlyQuery") {
+
+ val results = df.select("KEY_FIELD").take(10)
+ assert(results.length == 5)
+
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+ assert(executionRules.columnFilterCollection == null)
+ assert(executionRules.requiredQualifierDefinitionArray.length == 0)
+
+ }
+
+ /**
+ * A complex query with one point and one range for both the
+ * rowKey and a column
+ */
+ test("testSQLPointAndRangeCombo") {
+ val results = sqlContext.sql("SELECT KEY_FIELD FROM hbaseTmp " +
+ "WHERE " +
+ "(KEY_FIELD = 'get1' and B_FIELD < '3') or " +
+ "(KEY_FIELD >= 'get3' and B_FIELD = '8')").take(5)
+
+
+ assert(results.length == 3)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+
+ assert(executionRules.columnFilterCollection.columnFilterMap.size == 2)
+ val keyFieldFilter = executionRules.columnFilterCollection.columnFilterMap.get("KEY_FIELD").get
+ assert(keyFieldFilter.ranges.length == 1)
+ assert(keyFieldFilter.ranges.head.upperBound == null)
+ assert(Bytes.toString(keyFieldFilter.ranges.head.lowerBound).equals("get3"))
+ assert(keyFieldFilter.ranges.head.isLowerBoundEqualTo)
+ assert(keyFieldFilter.points.length == 1)
+ assert(Bytes.toString(keyFieldFilter.points.head).equals("get1"))
+
+ val bFieldFilter = executionRules.columnFilterCollection.columnFilterMap.get("B_FIELD").get
+ assert(bFieldFilter.ranges.length == 1)
+ assert(bFieldFilter.ranges.head.lowerBound.length == 0)
+ assert(Bytes.toString(bFieldFilter.ranges.head.upperBound).equals("3"))
+ assert(!bFieldFilter.ranges.head.isUpperBoundEqualTo)
+ assert(bFieldFilter.points.length == 1)
+ assert(Bytes.toString(bFieldFilter.points.head).equals("8"))
+
+ assert(executionRules.requiredQualifierDefinitionArray.length == 1)
+ assert(executionRules.requiredQualifierDefinitionArray.head.columnName.equals("B_FIELD"))
+ assert(executionRules.requiredQualifierDefinitionArray.head.columnFamily.equals("c"))
+ assert(executionRules.requiredQualifierDefinitionArray.head.qualifier.equals("b"))
+ }
+
+ /**
+ * A complex query with two ranges that do not merge into one
+ */
+ test("testTwoCompleteRangeNonMergeRowKeyQuery") {
+
+ val results = sqlContext.sql("SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTmp " +
+ "WHERE " +
+ "( KEY_FIELD >= 'get1' and KEY_FIELD <= 'get2') or" +
+ "( KEY_FIELD > 'get3' and KEY_FIELD <= 'get5')").take(10)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+ assert(results.length == 4)
+
+ assert(executionRules.columnFilterCollection.columnFilterMap.size == 1)
+ val keyFieldFilter = executionRules.columnFilterCollection.columnFilterMap.get("KEY_FIELD").get
+ assert(keyFieldFilter.ranges.length == 2)
+ assert(keyFieldFilter.points.length == 0)
+
+ assert(executionRules.requiredQualifierDefinitionArray.length == 2)
+ }
+
+ /**
+ * A complex query with two ranges that do merge into one
+ */
+ test("testTwoCompleteRangeMergeRowKeyQuery") {
+ val results = sqlContext.sql("SELECT KEY_FIELD, B_FIELD, A_FIELD FROM hbaseTmp " +
+ "WHERE " +
+ "( KEY_FIELD >= 'get1' and KEY_FIELD <= 'get3') or" +
+ "( KEY_FIELD > 'get3' and KEY_FIELD <= 'get5')").take(10)
+
+ val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
+
+ assert(results.length == 5)
+
+ assert(executionRules.columnFilterCollection.columnFilterMap.size == 1)
+ val keyFieldFilter = executionRules.columnFilterCollection.columnFilterMap.get("KEY_FIELD").get
+ assert(keyFieldFilter.ranges.length == 1)
+ assert(keyFieldFilter.points.length == 0)
+
+ assert(executionRules.requiredQualifierDefinitionArray.length == 2)
+ }
+}