diff --git a/hbase-spark/pom.xml b/hbase-spark/pom.xml
index e48f9e8..7417127 100644
--- a/hbase-spark/pom.xml
+++ b/hbase-spark/pom.xml
@@ -79,6 +79,12 @@
<groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
<artifactId>spark-streaming_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
diff --git a/hbase-spark/src/main/java/org/apache/hadoop/hbase/spark/PushDownFilterJava.java b/hbase-spark/src/main/java/org/apache/hadoop/hbase/spark/PushDownFilterJava.java
new file mode 100644
index 0000000..92691da
--- /dev/null
+++ b/hbase-spark/src/main/java/org/apache/hadoop/hbase/spark/PushDownFilterJava.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hbase.spark;
+
+import org.apache.hadoop.hbase.Cell;
+import org.apache.hadoop.hbase.CellUtil;
+import org.apache.hadoop.hbase.exceptions.DeserializationException;
+import org.apache.hadoop.hbase.filter.FilterBase;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutput;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+import java.util.HashMap;
+
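+/**
+ * Server-side HBase filter that evaluates the column predicates pushed down from Spark SQL.
+ * The map keys are column family/qualifier pairs; a cell whose column has a filter that
+ * rejects its value causes the whole row to be skipped (NEXT_ROW). The filter is shipped to
+ * the region servers using standard Java serialization (see toByteArray/parseFrom).
+ */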
+public class PushDownFilterJava extends FilterBase implements Serializable {
+
+ HashMap<ColumnFamilyQualifierWrapper, ColumnFilter> qualifierFilterTupleList;
+
+ public PushDownFilterJava(HashMap<ColumnFamilyQualifierWrapper, ColumnFilter> qualifierFilterTupleList) {
+ this.qualifierFilterTupleList = qualifierFilterTupleList;
+ }
+
+ @Override
+ public ReturnCode filterKeyValue(Cell c) throws IOException {
+ //TODO I'm sure this can be done better
+ ColumnFilter filter = qualifierFilterTupleList.get(new ColumnFamilyQualifierWrapper(
+ CellUtil.cloneFamily(c), CellUtil.cloneQualifier(c)));
+
+ if (filter == null) {
+ return ReturnCode.INCLUDE;
+ } else {
+ if (filter.validate(CellUtil.cloneValue(c))) {
+ return ReturnCode.INCLUDE;
+ } else {
+ return ReturnCode.NEXT_ROW;
+ }
+ }
+ }
+
+ public static PushDownFilterJava parseFrom(final byte [] bytes)
+ throws DeserializationException {
+
+ PushDownFilterJava result;
+
+ ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
+ ObjectInput in = null;
+ try {
+ in = new ObjectInputStream(bis);
+ result = (PushDownFilterJava)in.readObject();
+
+ } catch (Exception e) {
+ throw new DeserializationException(e);
+ } finally {
+ try {
+ bis.close();
+ } catch (IOException ex) {
+ // ignore close exception
+ }
+ try {
+ if (in != null) {
+ in.close();
+ }
+ } catch (IOException ex) {
+ // ignore close exception
+ }
+ }
+ return result;
+ }
+
+ /**
+ * @return The filter serialized using standard Java serialization
+ */
+ public byte [] toByteArray() throws IOException{
+ ByteArrayOutputStream bos = new ByteArrayOutputStream();
+ ObjectOutput out = null;
+ try {
+ out = new ObjectOutputStream(bos);
+ out.writeObject(this);
+ return bos.toByteArray();
+ } finally {
+ try {
+ if (out != null) {
+ out.close();
+ }
+ } catch (IOException ex) {
+ // ignore close exception
+ }
+ try {
+ bos.close();
+ } catch (IOException ex) {
+ // ignore close exception
+ }
+ }
+ }
+}
diff --git a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
new file mode 100644
index 0000000..c32cc06
--- /dev/null
+++ b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
@@ -0,0 +1,637 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hbase.spark
+
+import java.sql.Timestamp
+import java.util
+
+import org.apache.hadoop.hbase.client.{ConnectionFactory, Get, Result, Scan}
+import org.apache.hadoop.hbase.filter.Filter.ReturnCode
+import org.apache.hadoop.hbase.filter.FilterBase
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.hadoop.hbase.{CellUtil, Cell, TableName, HBaseConfiguration}
+import org.apache.spark.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types._
+
+import scala.collection.mutable
+
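+/**
+ * DefaultSource is the entry point Spark SQL uses when "org.apache.hadoop.hbase.spark" is
+ * given as a data source. createRelation reads the table name, the column mapping string,
+ * the scan batching/caching sizes and the optional HBase config resources from the
+ * parameters map and builds an HBaseRelation from them.
+ */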
+class DefaultSource extends RelationProvider {
+
+ val TABLE_KEY:String = "hbase.table"
+ val SCHEMA_COLUMNS_MAPPING_KEY:String = "hbase.columns.mapping"
+ val BATCHING_NUM_KEY:String = "hbase.batching.num"
+ val CACHING_NUM_KEY:String = "hbase.caching.num"
+ val HBASE_CONFIG_RESOURCES_LOCATIONS:String = "hbase.config.resources"
+ val USE_HBASE_CONTEXT:String = "hbase.use.hbase.context"
+
+ override def createRelation(sqlContext: SQLContext,
+ parameters: Map[String, String]):
+ BaseRelation = {
+
+ println("baseRelations")
+
+ val tableName = parameters.getOrElse(TABLE_KEY, "")
+ val schemaMappingString = parameters.getOrElse(SCHEMA_COLUMNS_MAPPING_KEY, "")
+ val batchingNumStr = parameters.getOrElse(BATCHING_NUM_KEY, "1000")
+ val cachingNumStr = parameters.getOrElse(CACHING_NUM_KEY, "1000")
+ val hbaseConfigResources = parameters.getOrElse(HBASE_CONFIG_RESOURCES_LOCATIONS, "")
+ val useHBaseContextStr = parameters.getOrElse(USE_HBASE_CONTEXT, "true")
+
+ if (tableName.isEmpty) {
+ throw new Throwable("Invalid value for " + TABLE_KEY + " '" + tableName + "'")
+ }
+
+ val batchingNum:Int = try {
+ batchingNumStr.toInt
+ } catch {
+ case e:Exception => throw
+ new Throwable("Invalid value for " + BATCHING_NUM_KEY +" '" + batchingNumStr + "'", e )
+ }
+
+ val cachingNum:Int = try {
+ cachingNumStr.toInt
+ } catch {
+ case e:Exception => throw
+ new Throwable("Invalid value for " + CACHING_NUM_KEY + " '" + cachingNumStr + "'", e )
+ }
+
+ new HBaseRelation(tableName,
+ generateSchemaMappingMap(schemaMappingString),
+ batchingNum,
+ cachingNum,
+ hbaseConfigResources,
+ useHBaseContextStr.equalsIgnoreCase("true"))(sqlContext)
+ }
+
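+ /**
+ * Parses the value of hbase.columns.mapping. The expected format is a comma separated
+ * list of "SPARK_COLUMN TYPE family:qualifier" entries, for example
+ * "KEY_FIELD STRING :key, A_FIELD STRING c:a"; an entry whose third part starts with ':'
+ * (an empty column family) is treated as the row key.
+ */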
+ def generateSchemaMappingMap(schemaMappingString:String): mutable.Map[String, SchemaQualifierDefinition] = {
+ try {
+ val columnDefinitions = schemaMappingString.split(',')
+ val resultingMap = new mutable.HashMap[String, SchemaQualifierDefinition]()
+ columnDefinitions.map(cd => {
+ val parts = cd.trim.split(' ')
+ val hbaseDefinitionParts = if (parts(2).charAt(0) == ':') {
+ Array[String]("", "key")
+ } else {
+ parts(2).split(':')
+ }
+ resultingMap.+=((parts(0), new SchemaQualifierDefinition(parts(0),
+ parts(1), hbaseDefinitionParts(0), hbaseDefinitionParts(1))))
+ })
+ resultingMap
+ } catch {
+ case e:Exception => throw
+ new Throwable("Invalid value for " + SCHEMA_COLUMNS_MAPPING_KEY +
+ " '" + schemaMappingString + "'", e )
+ }
+ }
+}
+
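+/**
+ * HBaseRelation implements Spark SQL's PrunedFilteredScan. buildScan translates the
+ * pushed-down Spark filters into HBase Gets (for row key point lookups) and Scans (for row
+ * key ranges), and attaches a PushDownFilterJava so non-row-key column predicates are
+ * evaluated on the region servers.
+ */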
+class HBaseRelation (tableName:String,
+ schemaMappingDefinition:mutable.Map[String, SchemaQualifierDefinition],
+ batchingNum:Int,
+ cachingNum:Int,
+ configResources:String,
+ useHBaseContext:Boolean) (
+ @transient val sqlContext:SQLContext)
+ extends BaseRelation with PrunedFilteredScan with Logging {
+
+ //create or get latest HBaseContext
+ @transient val hbaseContext:HBaseContext = if (useHBaseContext) {
+ LatestHBaseContextCache.latest
+ } else {
+ val config = HBaseConfiguration.create()
+ configResources.split(",").foreach( r => config.addResource(r))
+ new HBaseContext(sqlContext.sparkContext, config)
+ }
+
+ override def schema: StructType = {
+ println("schema")
+ val result = new StructType(schemaMappingDefinition.values.map(c => {
+ println(" - columnName:" + c.columnName + " c.columnSparkSqlType:" + c.columnSparkSqlType)
+ val metadata = new MetadataBuilder().putString("name", c.columnName)
+
+ new StructField(c.columnName, c.columnSparkSqlType, nullable = true, metadata.build())
+ }).toArray)
+ //TODO push schema to listener
+ result
+ }
+
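+ /**
+ * Overall flow: fold the Spark filters into a ColumnFilterCollection, turn row key
+ * point predicates into Gets and row key ranges into Scans, union the resulting RDDs,
+ * and fall back to a full table scan when no filters were pushed down.
+ */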
+ override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
+ println("buildScan")
+ filters.foreach(f => {
+ println("--- Start Root Filter")
+ printlnOutFilter(f)
+ println("--- Finished Root Filter")
+ })
+ val columnFilterCollection = buildColumnFilterCollection(filters)
+ println("columnFilterCollection:" + columnFilterCollection)
+
+ val serializableMap = new java.util.HashMap[String, SchemaQualifierDefinition]
+ schemaMappingDefinition.foreach( e => serializableMap.put(e._1, e._2))
+
+ var resultRDD: RDD[Row] = null
+
+ if (columnFilterCollection != null) {
+ val pushDownFilterJava = new PushDownFilterJava(columnFilterCollection.generateFamilyQualifierFilterMap(serializableMap))
+
+ val getList = new util.ArrayList[Get]()
+ val rddList = new util.ArrayList[RDD[Row]]()
+
+ val it = columnFilterCollection.columnFilterMap.iterator
+ while (it.hasNext) {
+ val e = it.next()
+ val columnDefinition = schemaMappingDefinition.getOrElse(e._1, null)
+ //check is a rowKey
+ if (columnDefinition != null && columnDefinition.columnFamily.isEmpty) {
+ //add points to getList
+ e._2.points.foreach(p => getList.add(new Get(p)))
+
+ val rangeIt = e._2.ranges.iterator
+
+ while (rangeIt.hasNext) {
+ val r = rangeIt.next()
+
+ val scan = new Scan()
+ scan.setBatch(batchingNum)
+ scan.setCaching(cachingNum)
+
+ if (pushDownFilterJava.qualifierFilterTupleList.size() > 0) {
+ scan.setFilter(pushDownFilterJava)
+ }
+
+ if (r.lowerBound != null && r.lowerBound.size > 0) {
+ if (r.isLowerBoundEqualTo) {
+ scan.setStartRow(r.lowerBound)
+ } else {
+ // append a zero byte to get the first row strictly greater than the exclusive lower bound
+ val newArray = new Array[Byte](r.lowerBound.length + 1)
+ System.arraycopy(r.lowerBound, 0, newArray, 0, r.lowerBound.length)
+ newArray(r.lowerBound.length) = 0.toByte
+ scan.setStartRow(newArray)
+ }
+ println( " [[ Lower: " + Bytes.toString(r.lowerBound) + " ]] ")
+ }
+ if (r.upperBound != null && r.upperBound.size > 0) {
+ if (r.isUpperBoundEqualTo) {
+ // stopRow is exclusive, so append a zero byte to include the upper bound itself
+ val newArray = new Array[Byte](r.upperBound.length + 1)
+ System.arraycopy(r.upperBound, 0, newArray, 0, r.upperBound.length)
+ newArray(r.upperBound.length) = 0.toByte
+ scan.setStopRow(newArray)
+ println( " [[ Upper=: " + Bytes.toString(r.upperBound) + " ]] ")
+ } else {
+ scan.setStopRow(r.upperBound)
+ println( " [[ Upper: " + Bytes.toString(r.upperBound) + " ]] ")
+ }
+ }
+
+ println("Scan:" + scan)
+
+ val rdd = hbaseContext.hbaseRDD(TableName.valueOf(tableName), scan).map(r => {
+ Row.fromSeq(requiredColumns.map(c => Utils.getValue(c, serializableMap, r._2)))
+ })
+ rddList.add(rdd)
+ }
+ }
+ }
+
+ for (i <- 0 until rddList.size()) {
+ if (resultRDD == null) resultRDD = rddList.get(i)
+ else {
+ resultRDD = resultRDD.union(rddList.get(i))
+ }
+ }
+
+ if (getList.size() > 0) {
+ val connection = ConnectionFactory.createConnection(hbaseContext.tmpHdfsConfiguration)
+ val table = connection.getTable(TableName.valueOf(tableName))
+ val results = table.get(getList)
+ // the Results are fully materialized, so the table and connection can be closed here
+ table.close()
+ connection.close()
+ val rowList = mutable.MutableList[Row]()
+ for (i <- 0 until results.length) {
+ val rowArray = requiredColumns.map(c => Utils.getValue(c, serializableMap, results(i)))
+ rowList += (Row.fromSeq(rowArray))
+ }
+ val getRDD = sqlContext.sparkContext.parallelize(rowList)
+ if (resultRDD == null) resultRDD = getRDD
+ else {
+ resultRDD = resultRDD.union(getRDD)
+ }
+ }
+ }
+ if (resultRDD == null) {
+ val scan = new Scan()
+ scan.setBatch(batchingNum)
+ scan.setCaching(cachingNum)
+
+ val rdd = hbaseContext.hbaseRDD(TableName.valueOf(tableName), scan).map(r => {
+ Row.fromSeq(requiredColumns.map(c => Utils.getValue(c, serializableMap, r._2)))
+ })
+ resultRDD=rdd
+ }
+ resultRDD
+ }
+
+ def buildColumnFilterCollection(filters: Array[Filter]): ColumnFilterCollection = {
+ var superCollection:ColumnFilterCollection = null
+ filters.foreach( f => {
+ val parentCollection = new ColumnFilterCollection
+ buildColumnFilterCollection(parentCollection, f)
+ if (superCollection == null) superCollection = parentCollection
+ else {superCollection.andAppend(parentCollection)}
+ })
+ superCollection
+ }
+
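+ /**
+ * Recursively folds a Spark SQL filter tree into the given collection: comparison
+ * predicates become per-column points or ScanRanges, and Or/And nodes are combined with
+ * orAppend/andAppend on the collections built for their children.
+ */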
+ def buildColumnFilterCollection(parentFilterCollection:ColumnFilterCollection, filter:Filter): Unit = {
+ filter match {
+ case EqualTo(attr, value) =>
+ parentFilterCollection.orAppend(attr,
+ new ColumnFilter(Utils.getByteValue(attr,schemaMappingDefinition, value.toString)))
+ case LessThan(attr, value) =>
+ parentFilterCollection.orAppend(attr, new ColumnFilter(null,
+ new ScanRange(Utils.getByteValue(attr,schemaMappingDefinition, value.toString), false,
+ new Array[Byte](0), true)))
+
+ case GreaterThan(attr, value) =>
+ parentFilterCollection.orAppend(attr, new ColumnFilter(null,
+ new ScanRange(null, true, Utils.getByteValue(attr,schemaMappingDefinition, value.toString), false)))
+
+ case LessThanOrEqual(attr, value) =>
+ parentFilterCollection.orAppend(attr, new ColumnFilter(null,
+ new ScanRange(Utils.getByteValue(attr,schemaMappingDefinition, value.toString), true,
+ new Array[Byte](0), true)))
+
+ case GreaterThanOrEqual(attr, value) =>
+ parentFilterCollection.orAppend(attr, new ColumnFilter(null,
+ new ScanRange(null, true, Utils.getByteValue(attr,schemaMappingDefinition, value.toString), true)))
+
+ case Or(left, right) =>
+ //println("==OR")
+ //println("===OR1:" + parentFilterCollection)
+ buildColumnFilterCollection(parentFilterCollection, left)
+ val rightSideCollection = new ColumnFilterCollection
+ buildColumnFilterCollection(rightSideCollection, right)
+ //println("===OR2:" + rightSideCollection)
+ parentFilterCollection.orAppend(rightSideCollection)
+ //println("===OR3:" + parentFilterCollection)
+ case And(left, right) =>
+ //println("==AND")
+ buildColumnFilterCollection(parentFilterCollection, left)
+ //println("===AND1:" + parentFilterCollection)
+ val rightSideCollection = new ColumnFilterCollection
+ buildColumnFilterCollection(rightSideCollection, right)
+ //println("===AND2:" + rightSideCollection)
+ parentFilterCollection.andAppend(rightSideCollection)
+ //println("===AND3:" + parentFilterCollection)
+ case _ =>
+ println("Skipping filter: ")
+ }
+ }
+
+
+ def printlnOutFilter(f: Filter): Unit = {
+ f match {
+ case EqualTo(attr, value) => println(" - EqualTo", attr, value)
+ case LessThan(attr, value) => println(" - LessThan", attr, value)
+ case GreaterThan(attr, value) => println(" - GreaterThan", attr, value)
+ case LessThanOrEqual(attr, value) => println(" - LessThanOrEqual", attr, value)
+ case GreaterThanOrEqual(attr, value) => println(" - GreaterThanOrEqual", attr, value)
+ case Or(left, right) =>
+ printlnOutFilter(left)
+ println(" OR ")
+ printlnOutFilter(right)
+ case And(left, right) =>
+ printlnOutFilter(left)
+ println(" AND ")
+ printlnOutFilter(right)
+ case _ =>
+ println("Skipping filter: " + f)
+ }
+ }
+}
+
+
+
+case class SchemaQualifierDefinition(columnName:String,
+ colType:String,
+ columnFamily:String,
+ qualifier:String) extends Serializable {
+ val columnFamilyBytes = Bytes.toBytes(columnFamily)
+ val qualifierBytes = Bytes.toBytes(qualifier)
+ val columnSparkSqlType = if (colType.equals("BOOLEAN")) BooleanType
+ else if (colType.equals("TINYINT")) IntegerType
+ else if (colType.equals("INT")) IntegerType
+ else if (colType.equals("BIGINT")) LongType
+ else if (colType.equals("FLOAT")) FloatType
+ else if (colType.equals("DOUBLE")) DoubleType
+ else if (colType.equals("STRING")) StringType
+ else if (colType.equals("TIMESTAMP")) TimestampType
+ else if (colType.equals("DECIMAL")) StringType //DataTypes.createDecimalType(precision, scale)
+ else throw new Throwable("Unsupported column type :" + colType)
+}
+
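+/**
+ * ScanRange models a row key or value range with upper/lower bounds that can each be
+ * inclusive or exclusive. mergeUnion widens this range in place against another range and
+ * mergeIntersect narrows it; both are used when combining pushed-down filters.
+ */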
+class ScanRange(var upperBound:Array[Byte], var isUpperBoundEqualTo:Boolean,
+ var lowerBound:Array[Byte], var isLowerBoundEqualTo:Boolean) extends Serializable {
+
+ def mergeIntersect(other:ScanRange): Unit = {
+ val upperBoundCompare = compareRange(upperBound, other.upperBound)
+ val lowerBoundCompare = compareRange(lowerBound, other.lowerBound)
+
+ upperBound = if (upperBoundCompare <0) upperBound else other.upperBound
+ lowerBound = if (lowerBoundCompare >0) lowerBound else other.lowerBound
+
+ isLowerBoundEqualTo = if (lowerBoundCompare == 0)
+ isLowerBoundEqualTo && other.isLowerBoundEqualTo
+ else isLowerBoundEqualTo
+
+ isUpperBoundEqualTo = if (upperBoundCompare == 0)
+ isUpperBoundEqualTo && other.isUpperBoundEqualTo
+ else isUpperBoundEqualTo
+ }
+
+ def mergeUnion(other:ScanRange): Unit = {
+
+ val upperBoundCompare = compareRange(upperBound, other.upperBound)
+ val lowerBoundCompare = compareRange(lowerBound, other.lowerBound)
+
+ upperBound = if (upperBoundCompare >0) upperBound else other.upperBound
+ lowerBound = if (lowerBoundCompare <0) lowerBound else other.lowerBound
+
+ isLowerBoundEqualTo = if (lowerBoundCompare == 0)
+ isLowerBoundEqualTo || other.isLowerBoundEqualTo
+ else isLowerBoundEqualTo
+
+ isUpperBoundEqualTo = if (upperBoundCompare == 0)
+ isUpperBoundEqualTo || other.isUpperBoundEqualTo
+ else isUpperBoundEqualTo
+ }
+
+ def doesOverLap(other:ScanRange): Boolean = {
+ if (compareRange(other.upperBound, lowerBound) >= 0 ||
+ compareRange(other.lowerBound, upperBound) >= 0){
+ true
+ } else {
+ false
+ }
+ }
+
+ def compareRange(left:Array[Byte], right:Array[Byte]): Int = {
+ if (left == null && right == null) 0
+ else if (left == null && right != null) 1
+ else if (left != null && right == null) -1
+ else Bytes.compareTo(left, right)
+ }
+
+ override def toString():String = {
+ "ScanRange:(" + Bytes.toString(upperBound) + "," + isUpperBoundEqualTo + "," +
+ Bytes.toString(lowerBound) + "," + isLowerBoundEqualTo + ")"
+ }
+}
+
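+/**
+ * ColumnFilter holds the values accepted for a single column, either as exact match points
+ * or as ScanRanges. validate() checks one cell value against them, and orAppend/andAppend
+ * merge in the points and ranges of another ColumnFilter.
+ */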
+class ColumnFilter (var currentPoint:Array[Byte] = null,
+ var currentRange:ScanRange = null) extends Serializable {
+ var ranges = new mutable.MutableList[ScanRange]()
+ if (currentRange != null ) ranges.+=(currentRange)
+
+ var points = new mutable.MutableList[Array[Byte]]()
+ if (currentPoint != null) {
+ points.+=(currentPoint)
+ }
+
+ def validate(value:Array[Byte]):Boolean = {
+ var result = false
+
+ points.foreach( p => {
+ if (Bytes.equals(p, value)) {
+ result = true
+ }
+ })
+
+ ranges.foreach( r => {
+ val upperBoundPass = r.upperBound == null ||
+ (r.isUpperBoundEqualTo && Bytes.compareTo(r.upperBound, value) >= 0) ||
+ (!r.isUpperBoundEqualTo && Bytes.compareTo(r.upperBound, value) > 0)
+ val lowerBoundPass = r.lowerBound == null || r.lowerBound.size == 0 ||
+ (r.isLowerBoundEqualTo && Bytes.compareTo(r.lowerBound, value) <= 0) ||
+ (!r.isLowerBoundEqualTo && Bytes.compareTo(r.lowerBound, value) < 0)
+
+ println("Filter: " + Bytes.toString(value) + ":" + upperBoundPass + "," + lowerBoundPass + " " + result)
+ if (r.upperBound != null) println(" upper: " + Bytes.toString(r.upperBound))
+ if (r.lowerBound != null) println(" lower: " + Bytes.toString(r.lowerBound))
+
+ result = result || (upperBoundPass && lowerBoundPass)
+ })
+ result
+ }
+
+ def orAppend(other:ColumnFilter): Unit = {
+ other.points.foreach( p => points += p)
+
+ other.ranges.foreach( otherR => {
+ var doesOverLap = false
+ ranges.foreach{ r =>
+ if (r.doesOverLap(otherR)) {
+ r.mergeUnion(otherR)
+ doesOverLap = true
+ }}
+ if (!doesOverLap) ranges.+=(otherR)
+ })
+ }
+
+ def andAppend(other:ColumnFilter): Unit = {
+ val survivingPoints = new mutable.MutableList[Array[Byte]]()
+ points.foreach( p => {
+ other.points.foreach( otherP => {
+ if (Bytes.equals(p, otherP)) {
+ survivingPoints.+=(p)
+ }
+ })
+ })
+ points = survivingPoints
+
+ val survivingRanges = new mutable.MutableList[ScanRange]()
+
+ other.ranges.foreach( otherR => {
+ ranges.foreach( r => {
+ if (r.doesOverLap(otherR)) {
+ r.mergeIntersect(otherR)
+ survivingRanges += r
+ }
+ })
+ })
+ ranges = survivingRanges
+ }
+
+ override def toString:String = {
+ val strBuilder = new StringBuilder
+ strBuilder.append("(points:(")
+ var isFirst = true
+ points.foreach( p => {
+ if (isFirst) isFirst = false
+ else strBuilder.append(",")
+ strBuilder.append(Bytes.toString(p))
+ })
+ strBuilder.append("),ranges:")
+ isFirst = true
+ ranges.foreach( r => {
+ if (isFirst) isFirst = false
+ else strBuilder.append(",")
+ strBuilder.append(r)
+ })
+ strBuilder.append("))")
+ strBuilder.toString()
+ }
+}
+
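+/**
+ * ColumnFilterCollection groups ColumnFilters by Spark column name and knows how to OR or
+ * AND itself with another collection. generateFamilyQualifierFilterMap converts the
+ * non-row-key entries into the map consumed by PushDownFilterJava.
+ */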
+class ColumnFilterCollection {
+ val columnFilterMap = new mutable.HashMap[String, ColumnFilter]
+
+ def orAppend(column:String, columnFilter:ColumnFilter): Unit = {
+ val existingFilter = columnFilterMap.get(column)
+ if (existingFilter.isEmpty) {
+ columnFilterMap.+=((column, columnFilter))
+ } else {
+ existingFilter.get.orAppend(columnFilter)
+ }
+ }
+
+ def orAppend(columnFilterColumn:ColumnFilterCollection): Unit = {
+ columnFilterColumn.columnFilterMap.foreach( e => {
+ orAppend(e._1, e._2)
+ })
+ }
+
+ def andAppend(columnFilterColumn:ColumnFilterCollection): Unit = {
+ columnFilterColumn.columnFilterMap.foreach( e => {
+ val existingColumnFilter = columnFilterMap.get(e._1)
+ if (existingColumnFilter.isEmpty) {
+ columnFilterMap += e
+ } else {
+ existingColumnFilter.get.andAppend(e._2)
+ }
+ })
+ }
+
+ def generateFamilyQualifierFilterMap(schemaDefinitionMap:
+ java.util.HashMap[String, SchemaQualifierDefinition]):
+ util.HashMap[ColumnFamilyQualifierWrapper, ColumnFilter] = {
+ val familyQualifierFilterMap = new util.HashMap[ColumnFamilyQualifierWrapper, ColumnFilter]()
+ columnFilterMap.foreach( e => {
+ val definition = schemaDefinitionMap.get(e._1)
+ //Don't add the row key filter, and skip columns missing from the schema mapping
+ if (definition != null && definition.columnFamilyBytes.size > 0) {
+ familyQualifierFilterMap.put(
+ new ColumnFamilyQualifierWrapper(definition.columnFamilyBytes, definition.qualifierBytes), e._2)
+ }
+
+ })
+ familyQualifierFilterMap
+ }
+
+ override def toString:String = {
+ val strBuilder = new StringBuilder
+ columnFilterMap.foreach( e => strBuilder.append(e))
+ strBuilder.toString()
+ }
+}
+
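+/**
+ * Map key that wraps a column family and qualifier and compares the byte arrays by content
+ * rather than by reference, so lookups built from cells work in the pushdown filter map.
+ */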
+class ColumnFamilyQualifierWrapper(val columnFamily:Array[Byte], val qualifier:Array[Byte])
+ extends Serializable{
+
+ override def equals(other:Any): Boolean = {
+ if (other.isInstanceOf[ColumnFamilyQualifierWrapper]) {
+ val otherWrapper = other.asInstanceOf[ColumnFamilyQualifierWrapper]
+ Bytes.compareTo(columnFamily, otherWrapper.columnFamily) == 0 &&
+ Bytes.compareTo(qualifier, otherWrapper.qualifier) == 0
+ } else {
+ false
+ }
+ }
+
+ override def hashCode():Int = {
+ Bytes.hashCode(columnFamily) + Bytes.hashCode(qualifier)
+ }
+}
+
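+/**
+ * Helpers that convert between HBase byte values and Spark SQL types using the schema
+ * mapping: getValue decodes a row key or cell for a column, getByteValue encodes a filter
+ * value for comparisons.
+ */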
+object Utils {
+ def getValue(columnName: String,
+ schemaMappingDefinition: java.util.HashMap[String, SchemaQualifierDefinition],
+ r: Result): Any = {
+
+ val columnDef = schemaMappingDefinition.get(columnName)
+
+ if (columnDef == null) throw new Throwable("Unknown column:" + columnName)
+
+
+ if (columnDef.columnFamilyBytes.isEmpty) {
+ val rowKey = r.getRow
+
+ columnDef.columnSparkSqlType match {
+ case IntegerType => Bytes.toInt(rowKey)
+ case LongType => Bytes.toLong(rowKey)
+ case FloatType => Bytes.toFloat(rowKey)
+ case DoubleType => Bytes.toDouble(rowKey)
+ case StringType => Bytes.toString(rowKey)
+ case TimestampType => new Timestamp(Bytes.toLong(rowKey))
+ case _ => Bytes.toString(rowKey)
+ }
+ } else {
+ val cellByteValue =
+ r.getColumnLatestCell(columnDef.columnFamilyBytes, columnDef.qualifierBytes)
+ if (cellByteValue == null) null
+ else columnDef.columnSparkSqlType match {
+ case IntegerType => Bytes.toInt(cellByteValue.getValueArray,
+ cellByteValue.getValueOffset, cellByteValue.getValueLength)
+ case LongType => Bytes.toLong(cellByteValue.getValueArray,
+ cellByteValue.getValueOffset, cellByteValue.getValueLength)
+ case FloatType => Bytes.toFloat(cellByteValue.getValueArray,
+ cellByteValue.getValueOffset)
+ case DoubleType => Bytes.toDouble(cellByteValue.getValueArray,
+ cellByteValue.getValueOffset)
+ case StringType => Bytes.toString(cellByteValue.getValueArray,
+ cellByteValue.getValueOffset, cellByteValue.getValueLength)
+ case TimestampType => new Timestamp(Bytes.toLong(cellByteValue.getValueArray,
+ cellByteValue.getValueOffset, cellByteValue.getValueLength))
+ case _ => Bytes.toString(cellByteValue.getValueArray,
+ cellByteValue.getValueOffset, cellByteValue.getValueLength)
+ }
+ }
+ }
+ def getByteValue(columnName: String,
+ schemaMappingDefinition: mutable.Map[String, SchemaQualifierDefinition],
+ value: String): Array[Byte] = {
+
+ val columnDef = schemaMappingDefinition.get(columnName)
+
+ // mutable.Map.get returns an Option, so check isEmpty rather than comparing to null
+ if (columnDef.isEmpty) { throw new Throwable("Unknown column:" + columnName) }
+ else {
+ columnDef.get.columnSparkSqlType match {
+ case IntegerType => Bytes.toBytes(value.toInt)
+ case LongType => Bytes.toBytes(value.toLong)
+ case FloatType => Bytes.toBytes(value.toFloat)
+ case DoubleType => Bytes.toBytes(value.toDouble)
+ case StringType => Bytes.toBytes(value)
+ case TimestampType => Bytes.toBytes(value.toLong)
+ case _ => Bytes.toBytes(value)
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
index f060fea..ab4dbf5 100644
--- a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
+++ b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
@@ -64,6 +64,8 @@ class HBaseContext(@transient sc: SparkContext,
val broadcastedConf = sc.broadcast(new SerializableWritable(config))
val credentialsConf = sc.broadcast(new SerializableWritable(job.getCredentials))
+ LatestHBaseContextCache.latest = this
+
if (tmpHdfsConfgFile != null && config != null) {
val fs = FileSystem.newInstance(config)
val tmpPath = new Path(tmpHdfsConfgFile)
@@ -568,3 +570,7 @@ class HBaseContext(@transient sc: SparkContext,
private[spark]
def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]]
}
+
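+/**
+ * Holds the most recently constructed HBaseContext so that DefaultSource can reuse it when
+ * hbase.use.hbase.context is true.
+ */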
+object LatestHBaseContextCache {
+ var latest:HBaseContext = null
+}
diff --git a/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala b/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala
new file mode 100644
index 0000000..e03c0c2
--- /dev/null
+++ b/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hbase.spark
+
+import org.apache.hadoop.hbase.client.{Put, ConnectionFactory}
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.hadoop.hbase.{TableName, HBaseTestingUtility}
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.{SparkContext, Logging}
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
+
+class DefaultSourceSuite extends FunSuite with
+BeforeAndAfterEach with BeforeAndAfterAll with Logging {
+ @transient var sc: SparkContext = null
+ var TEST_UTIL: HBaseTestingUtility = new HBaseTestingUtility
+
+ val tableName = "t1"
+ val columnFamily = "c"
+
+ override def beforeAll() {
+
+ TEST_UTIL.startMiniCluster
+
+ logInfo(" - minicluster started")
+ try
+ TEST_UTIL.deleteTable(TableName.valueOf(tableName))
+ catch {
+ case e: Exception => logInfo(" - no table " + tableName + " found")
+
+ }
+ logInfo(" - creating table " + tableName)
+ TEST_UTIL.createTable(TableName.valueOf(tableName), Bytes.toBytes(columnFamily))
+ logInfo(" - created table")
+
+ sc = new SparkContext("local", "test")
+ }
+
+ override def afterAll() {
+ TEST_UTIL.deleteTable(TableName.valueOf(tableName))
+ logInfo("shuting down minicluster")
+ TEST_UTIL.shutdownMiniCluster()
+
+ sc.stop()
+ }
+
+ test("dataframe.select test") {
+ val config = TEST_UTIL.getConfiguration
+ val connection = ConnectionFactory.createConnection(config)
+ val table = connection.getTable(TableName.valueOf("t1"))
+
+ try {
+ var put = new Put(Bytes.toBytes("get1"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("1"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("get2"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("4"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("get3"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("8"))
+ table.put(put)
+ put = new Put(Bytes.toBytes("get4"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo4"))
+ put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("10"))
+ table.put(put)
+ } finally {
+ table.close()
+ connection.close()
+ }
+
+ val hbaseContext = new HBaseContext(sc, config)
+ val sqlContext = new SQLContext(sc)
+
+ val df = sqlContext.load("org.apache.hadoop.hbase.spark",
+ Map("hbase.columns.mapping" -> "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b,",
+ "hbase.table" -> "t1"))
+
+ println("simple select")
+ //df.select("KEY_FIELD").foreach(r => println(" - " + r))
+
+ println("tempTable or test")
+ df.registerTempTable("hbaseTmp")
+
+ sqlContext.sql("SELECT KEY_FIELD FROM hbaseTmp " +
+ "WHERE " +
+ "(KEY_FIELD = 'get1' and B_FIELD < '3') or " +
+ "(KEY_FIELD <= 'get3' and B_FIELD = '8')").foreach(r => println(" - " + r))
+
+ println("------------------------")
+
+ sqlContext.sql("SELECT KEY_FIELD FROM hbaseTmp " +
+ "WHERE " +
+ "(KEY_FIELD = 'get1' and B_FIELD < '3') or " +
+ "(KEY_FIELD < 'get3' and B_FIELD = '8')").foreach(r => println(" - " + r))
+
+ println("------------------------")
+
+ sqlContext.sql("SELECT KEY_FIELD FROM hbaseTmp " +
+ "WHERE " +
+ "(KEY_FIELD > 'get1')").foreach(r => println(" - " + r))
+ println("------------------------")
+
+ sqlContext.sql("SELECT KEY_FIELD FROM hbaseTmp " +
+ "WHERE " +
+ "(KEY_FIELD >= 'get1')").foreach(r => println(" - " + r))
+
+ println("------------------------")
+
+ sqlContext.sql("SELECT KEY_FIELD FROM hbaseTmp " +
+ "WHERE " +
+ "(KEY_FIELD = 'get2' or KEY_FIELD = 'get1') and " +
+ "(B_FIELD < '3' or B_FIELD = '4')").foreach(r => println(" - " + r))
+
+ println("------------------------")
+
+ sqlContext.sql("SELECT KEY_FIELD FROM hbaseTmp " +
+ "WHERE " +
+ "(KEY_FIELD > 'get1' and KEY_FIELD < 'get3') or " +
+ "(KEY_FIELD <= 'get4' and B_FIELD = '8')").foreach(r => println(" - " + r))
+
+
+ sqlContext.sql("SELECT KEY_FIELD FROM hbaseTmp " +
+ "WHERE " +
+ "(A_FIELD = 'foo1' and B_FIELD < '3') or " +
+ "(A_FIELD < 'foo3' and B_FIELD = '8')").foreach(r => println(" - " + r))
+
+ println("------------------------")
+
+ sqlContext.sql("SELECT KEY_FIELD FROM hbaseTmp " +
+ "WHERE " +
+ "(A_FIELD = 'foo2' or A_FIELD = 'foo1') and " +
+ "(B_FIELD < '3' or B_FIELD = '4')").foreach(r => println(" - " + r))
+
+ println("------------------------")
+
+ sqlContext.sql("SELECT KEY_FIELD FROM hbaseTmp " +
+ "WHERE " +
+ "B_FIELD < '3' or " +
+ " B_FIELD <= '8'").foreach(r => println(" - " + r))
+ }
+}