diff --git a/hbase-spark/pom.xml b/hbase-spark/pom.xml
index e48f9e8..7417127 100644
--- a/hbase-spark/pom.xml
+++ b/hbase-spark/pom.xml
@@ -79,6 +79,12 @@
       <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${spark.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
       <artifactId>spark-streaming_${scala.binary.version}</artifactId>
       <version>${spark.version}</version>
diff --git a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
new file mode 100644
index 0000000..c32cc06
--- /dev/null
+++ b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
@@ -0,0 +1,637 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hbase.spark
+
+import java.sql.Timestamp
+import java.util
+
+import org.apache.hadoop.hbase.client.{ConnectionFactory, Get, Result, Scan}
+import org.apache.hadoop.hbase.filter.Filter.ReturnCode
+import org.apache.hadoop.hbase.filter.FilterBase
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.hadoop.hbase.{CellUtil, Cell, TableName, HBaseConfiguration}
+import org.apache.spark.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types._
+
+import scala.collection.mutable
+
+class DefaultSource extends RelationProvider {
+
+  val TABLE_KEY:String = "hbase.table"
+  val SCHEMA_COLUMNS_MAPPING_KEY:String = "hbase.columns.mapping"
+  val BATCHING_NUM_KEY:String = "hbase.batching.num"
+  val CATCHING_NUM_KEY:String = "hbase.catching.num"
+  val HBASE_CONFIG_RESOURCES_LOCATIONS:String = "hbase.config.resources"
+  val USE_HBASE_CONTEXT:String = "hbase.use.hbase.context"
+
+  override def createRelation(sqlContext: SQLContext,
+                              parameters: Map[String, String]):
+  BaseRelation = {
+
+    println("baseRelations")
+
+    val tableName = parameters.getOrElse(TABLE_KEY, "")
+    val schemaMappingString = parameters.getOrElse(SCHEMA_COLUMNS_MAPPING_KEY, "")
+    val batchingNumStr = parameters.getOrElse(BATCHING_NUM_KEY, "1000")
+    val catchingNumStr = parameters.getOrElse(CATCHING_NUM_KEY, "1000")
+    val hbaseConfigResources = parameters.getOrElse(HBASE_CONFIG_RESOURCES_LOCATIONS, "")
+    val useHBaseReources = parameters.getOrElse(USE_HBASE_CONTEXT, "true")
+
+    if (tableName.isEmpty) {
+      throw new Throwable("Invalid value for " + TABLE_KEY + " '" + tableName + "'")
+    }
+
+    val batchingNum:Int = try {
+      batchingNumStr.toInt
+    } catch {
+      case e:Exception => throw
+        new Throwable("Invalid value for " + BATCHING_NUM_KEY + " '" + batchingNumStr + "'", e)
+    }
+
+    val catchingNum:Int = try {
+      catchingNumStr.toInt
+    } catch {
+      case e:Exception => throw
+        new Throwable("Invalid value for " + CATCHING_NUM_KEY + " '" + catchingNumStr + "'", e)
+    }
+
+    new
HBaseRelation(tableName, + generateSchemaMappingMap(schemaMappingString), + batchingNum.toInt, + catchingNum.toInt, + hbaseConfigResources, + useHBaseReources.equalsIgnoreCase("true"))(sqlContext) + } + + def generateSchemaMappingMap(schemaMappingString:String): mutable.Map[String, SchemaQualifierDefinition] = { + try { + val columnDefinitions = schemaMappingString.split(',') + val resultingMap = new mutable.HashMap[String, SchemaQualifierDefinition]() + columnDefinitions.map(cd => { + val parts = cd.trim.split(' ') + val hbaseDefinitionParts = if (parts(2).charAt(0) == ':') { + Array[String]("", "key") + } else { + parts(2).split(':') + } + resultingMap.+=((parts(0), new SchemaQualifierDefinition(parts(0), + parts(1), hbaseDefinitionParts(0), hbaseDefinitionParts(1)))) + }) + resultingMap + } catch { + case e:Exception => throw + new Throwable("Invalid value for " + SCHEMA_COLUMNS_MAPPING_KEY + + " '" + schemaMappingString + "'", e ) + } + } +} + +class HBaseRelation (tableName:String, + schemaMappingDefinition:mutable.Map[String, SchemaQualifierDefinition], + batchingNum:Int, + cachingNum:Int, + configResources:String, + useHBaseContext:Boolean) ( + @transient val sqlContext:SQLContext) + extends BaseRelation with PrunedFilteredScan with Logging { + + //create or get latest HBaseContext + @transient val hbaseContext:HBaseContext = if (useHBaseContext) { + LatestHBaseContextCache.latest + } else { + val config = HBaseConfiguration.create() + configResources.split(",").foreach( r => config.addResource(r)) + new HBaseContext(sqlContext.sparkContext, config) + } + + override def schema: StructType = { + println("schema") + val result = new StructType(schemaMappingDefinition.values.map(c => { + println(" - columnName:" + c.columnName + " c.columnSparkSqlType:" + c.columnSparkSqlType) + val metadata = new MetadataBuilder().putString("name", c.columnName) + + new StructField(c.columnName, c.columnSparkSqlType, nullable = true, metadata.build()) + }).toArray) + //TODO push schema to listener + result + } + + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { + println("buildScan") + filters.foreach(f => { + println("--- Start Root Filter") + printlnOutFilter(f) + println("--- Finished Root Filter") + }) + val columnFilterCollection = buildColumnFilterCollection(filters) + println("columnFilterCollection:" + columnFilterCollection) + + val serializableMap = new java.util.HashMap[String, SchemaQualifierDefinition] + schemaMappingDefinition.foreach( e => serializableMap.put(e._1, e._2)) + + var resultRDD: RDD[Row] = null + + if (columnFilterCollection != null) { + val pushDownFilterJava = new PushDownFilterJava(columnFilterCollection.generateFamilyQualifiterFilterMap(serializableMap)) + + val getList = new util.ArrayList[Get]() + val rddList = new util.ArrayList[RDD[Row]]() + + val it = columnFilterCollection.columnFilterMap.iterator + while (it.hasNext) { + val e = it.next() + val columnDefinition = schemaMappingDefinition.getOrElse(e._1, null) + //check is a rowKey + if (columnDefinition != null && columnDefinition.columnFamily.isEmpty) { + //add points to getList + e._2.points.foreach(p => getList.add(new Get(p))) + + val rangeIt = e._2.ranges.iterator + + while (rangeIt.hasNext) { + val r = rangeIt.next() + + val scan = new Scan() + scan.setBatch(batchingNum) + scan.setCaching(cachingNum) + + if (pushDownFilterJava.qualifierFilterTupleList.size() > 0) { + scan.setFilter(pushDownFilterJava) + } + + if (r.lowerBound != null && r.lowerBound.size > 0) { + 
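+              // An inclusive lower bound can be used as the scan start row directly; an
+              // exclusive bound gets a trailing byte appended so the scan starts just past it.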
+              if (r.isLowerBoundEqualTo) {
+                scan.setStartRow(r.lowerBound)
+              } else {
+                val newArray = new Array[Byte](r.lowerBound.length + 1)
+                System.arraycopy(r.lowerBound, 0, newArray, 0, r.lowerBound.length)
+                newArray(r.lowerBound.length) = Byte.MinValue
+                scan.setStartRow(newArray)
+              }
+              println( " [[ Lower: " + Bytes.toString(r.lowerBound) + " ]] ")
+            }
+            if (r.upperBound != null && r.upperBound.size > 0) {
+              if (r.isUpperBoundEqualTo) {
+                val newArray = new Array[Byte](r.upperBound.length + 1)
+                System.arraycopy(r.upperBound, 0, newArray, 0, r.upperBound.length)
+                newArray(r.upperBound.length) = Byte.MinValue
+                scan.setStopRow(newArray)
+                println( " [[ Upper=: " + Bytes.toString(r.upperBound) + " ]] ")
+              } else {
+                scan.setStopRow(r.upperBound)
+                println( " [[ Upper: " + Bytes.toString(r.upperBound) + " ]] ")
+              }
+            }
+
+            println("Scan:" + scan)
+
+            val rdd = hbaseContext.hbaseRDD(TableName.valueOf(tableName), scan).map(r => {
+              Row.fromSeq(requiredColumns.map(c => Utils.getValue(c, serializableMap, r._2)))
+            })
+            rddList.add(rdd)
+          }
+        }
+      }
+
+      for (i <- 0 until rddList.size()) {
+        if (resultRDD == null) resultRDD = rddList.get(i)
+        else {
+          resultRDD = resultRDD.union(rddList.get(i))
+        }
+      }
+
+      if (getList.size() > 0) {
+        val connection = ConnectionFactory.createConnection(hbaseContext.tmpHdfsConfiguration)
+        val table = connection.getTable(TableName.valueOf(tableName))
+        val results = table.get(getList)
+        val rowList = mutable.MutableList[Row]()
+        for (i <- 0 until results.length) {
+          val rowArray = requiredColumns.map(c => Utils.getValue(c, serializableMap, results(i)))
+          rowList += (Row.fromSeq(rowArray))
+        }
+        table.close()
+        connection.close()
+        val getRDD = sqlContext.sparkContext.parallelize(rowList)
+        if (resultRDD == null) resultRDD = getRDD
+        else {
+          resultRDD = resultRDD.union(getRDD)
+        }
+      }
+    }
+    if (resultRDD == null) {
+      val scan = new Scan()
+      scan.setBatch(batchingNum)
+      scan.setCaching(cachingNum)
+
+      val rdd = hbaseContext.hbaseRDD(TableName.valueOf(tableName), scan).map(r => {
+        Row.fromSeq(requiredColumns.map(c => Utils.getValue(c, serializableMap, r._2)))
+      })
+      resultRDD = rdd
+    }
+    resultRDD
+  }
+
+  def buildColumnFilterCollection(filters: Array[Filter]): ColumnFilterCollection = {
+    var superCollection:ColumnFilterCollection = null
+    filters.foreach( f => {
+      val parentCollection = new ColumnFilterCollection
+      buildColumnFilterCollection(parentCollection, f)
+      if (superCollection == null) superCollection = parentCollection
+      else {superCollection.andAppend(parentCollection)}
+    })
+    superCollection
+  }
+
+  def buildColumnFilterCollection(parentFilterCollection:ColumnFilterCollection, filter:Filter): Unit = {
+    filter match {
+      case EqualTo(attr, value) =>
+        parentFilterCollection.orAppend(attr,
+          new ColumnFilter(Utils.getByteValue(attr,schemaMappingDefinition, value.toString)))
+      case LessThan(attr, value) =>
+        parentFilterCollection.orAppend(attr, new ColumnFilter(null,
+          new ScanRange(Utils.getByteValue(attr,schemaMappingDefinition, value.toString), false,
+            new Array[Byte](0), true)))
+
+      case GreaterThan(attr, value) =>
+        parentFilterCollection.orAppend(attr, new ColumnFilter(null,
+          new ScanRange(null, true, Utils.getByteValue(attr,schemaMappingDefinition, value.toString), false)))
+
+      case LessThanOrEqual(attr, value) =>
+        parentFilterCollection.orAppend(attr, new ColumnFilter(null,
+          new ScanRange(Utils.getByteValue(attr,schemaMappingDefinition, value.toString), true,
+            new Array[Byte](0), true)))
+
+      case GreaterThanOrEqual(attr, value) =>
+        parentFilterCollection.orAppend(attr, new ColumnFilter(null,
+          new
ScanRange(null, true, Utils.getByteValue(attr,schemaMappingDefinition, value.toString), true))) + + case Or(left, right) => + //println("==OR") + //println("===OR1:" + parentFilterCollection) + buildColumnFilterCollection(parentFilterCollection, left) + val rightSideCollection = new ColumnFilterCollection + buildColumnFilterCollection(rightSideCollection, right) + //println("===OR2:" + rightSideCollection) + parentFilterCollection.orAppend(rightSideCollection) + //println("===OR3:" + parentFilterCollection) + case And(left, right) => + //println("==AND") + buildColumnFilterCollection(parentFilterCollection, left) + //println("===AND1:" + parentFilterCollection) + val rightSideCollection = new ColumnFilterCollection + buildColumnFilterCollection(rightSideCollection, right) + //println("===AND2:" + rightSideCollection) + parentFilterCollection.andAppend(rightSideCollection) + //println("===AND3:" + parentFilterCollection) + case _ => + println("Skipping filter: ") + } + } + + + def printlnOutFilter(f: Filter): Unit = { + f match { + case EqualTo(attr, value) => println(" - EqualTo", attr, value) + case LessThan(attr, value) => println(" - LessThen", attr, value) + case GreaterThan(attr, value) => println(" - GreaterThen", attr, value) + case LessThanOrEqual(attr, value) => println(" - LessThenOrEqual", attr, value) + case GreaterThanOrEqual(attr, value) => println(" - GreateThenOrEqual", attr, value) + case Or(left, right) => + printlnOutFilter(left) + println(" OR ") + printlnOutFilter(right) + case And(left, right) => + printlnOutFilter(left) + println(" AND ") + printlnOutFilter(right) + case _ => + println("Skipping filter: " + f) + } + } +} + + + +case class SchemaQualifierDefinition(columnName:String, + colType:String, + columnFamily:String, + qualifier:String) extends Serializable { + val columnFamilyBytes = Bytes.toBytes(columnFamily) + val qualifierBytes = Bytes.toBytes(qualifier) + val columnSparkSqlType = if (colType.equals("BOOLEAN")) BooleanType + else if (colType.equals("TINYINT")) IntegerType + else if (colType.equals("INT")) IntegerType + else if (colType.equals("BIGINT")) LongType + else if (colType.equals("FLOAT")) FloatType + else if (colType.equals("DOUBLE")) DoubleType + else if (colType.equals("STRING")) StringType + else if (colType.equals("TIMESTAMP")) TimestampType + else if (colType.equals("DECIMAL")) StringType //DataTypes.createDecimalType(precision, scale) + else throw new Throwable("Unsupported column type :" + colType) +} + +class ScanRange(var upperBound:Array[Byte], var isUpperBoundEqualTo:Boolean, + var lowerBound:Array[Byte], var isLowerBoundEqualTo:Boolean) extends Serializable { + + def mergeIntersect(other:ScanRange): Unit = { + val upperBoundCompare = compareRange(upperBound, other.upperBound) + val lowerBoundCompare = compareRange(lowerBound, other.lowerBound) + + upperBound = if (upperBoundCompare <0) upperBound else other.upperBound + lowerBound = if (lowerBoundCompare >0) lowerBound else other.lowerBound + + isLowerBoundEqualTo = if (lowerBoundCompare == 0) + isLowerBoundEqualTo && other.isLowerBoundEqualTo + else isLowerBoundEqualTo + + isUpperBoundEqualTo = if (upperBoundCompare == 0) + isUpperBoundEqualTo && other.isUpperBoundEqualTo + else isUpperBoundEqualTo + } + + def mergeUnion(other:ScanRange): Unit = { + + val upperBoundCompare = compareRange(upperBound, other.upperBound) + val lowerBoundCompare = compareRange(lowerBound, other.lowerBound) + + upperBound = if (upperBoundCompare >0) upperBound else other.upperBound + lowerBound = if 
(lowerBoundCompare <0) lowerBound else other.lowerBound + + isLowerBoundEqualTo = if (lowerBoundCompare == 0) + isLowerBoundEqualTo || other.isLowerBoundEqualTo + else isLowerBoundEqualTo + + isUpperBoundEqualTo = if (upperBoundCompare == 0) + isUpperBoundEqualTo || other.isUpperBoundEqualTo + else isUpperBoundEqualTo + } + + def doesOverLap(other:ScanRange): Boolean = { + if (compareRange(other.upperBound, lowerBound) >= 0 || + compareRange(other.lowerBound, upperBound) >= 0){ + true + } else { + false + } + } + + def compareRange(left:Array[Byte], right:Array[Byte]): Int = { + if (left == null && right == null) 0 + else if (left == null && right != null) 1 + else if (left != null && right == null) -1 + else Bytes.compareTo(left, right) + } + + override def toString():String = { + "ScanRange:(" + Bytes.toString(upperBound) + "," + isUpperBoundEqualTo + "," + + Bytes.toString(lowerBound) + "," + isLowerBoundEqualTo + ")" + } +} + +class ColumnFilter (var currentPoint:Array[Byte] = null, + var currentRange:ScanRange = null) extends Serializable { + var ranges = new mutable.MutableList[ScanRange]() + if (currentRange != null ) ranges.+=(currentRange) + + var points = new mutable.MutableList[Array[Byte]]() + if (currentPoint != null) { + points.+=(currentPoint) + } + + def validate(value:Array[Byte]):Boolean = { + var result = false + + points.foreach( p => { + if (Bytes.equals(p, value)) { + result = true + } + }) + + ranges.foreach( r => { + val upperBoundPass = r.upperBound == null || + (r.isUpperBoundEqualTo && Bytes.compareTo(r.upperBound, value) >= 0) || + (!r.isUpperBoundEqualTo && Bytes.compareTo(r.upperBound, value) > 0) + val lowerBoundPass = r.lowerBound == null || r.lowerBound.size == 0 + (r.isLowerBoundEqualTo && Bytes.compareTo(r.lowerBound, value) <= 0) || + (!r.isLowerBoundEqualTo && Bytes.compareTo(r.lowerBound, value) < 0) + + println("Filter: " + Bytes.toString(value) + ":" + upperBoundPass + "," + lowerBoundPass + " " + result) + if (r.upperBound != null) println(" upper: " + Bytes.toString(r.upperBound)) + if (r.lowerBound != null) println(" lower: " + Bytes.toString(r.lowerBound)) + + result = result || (upperBoundPass && lowerBoundPass) + }) + result + } + + def orAppend(other:ColumnFilter): Unit = { + other.points.foreach( p => points += p) + + other.ranges.foreach( otherR => { + var doesOverLap = false + ranges.foreach{ r => + if (r.doesOverLap(otherR)) { + r.mergeUnion(otherR) + doesOverLap = true + }} + if (!doesOverLap) ranges.+=(otherR) + }) + } + + def andAppend(other:ColumnFilter): Unit = { + val survivingPoints = new mutable.MutableList[Array[Byte]]() + points.foreach( p => { + other.points.foreach( otherP => { + if (Bytes.equals(p, otherP)) { + survivingPoints.+=(p) + } + }) + }) + points = survivingPoints + + val survivingRanges = new mutable.MutableList[ScanRange]() + + other.ranges.foreach( otherR => { + ranges.foreach( r => { + if (r.doesOverLap(otherR)) { + r.mergeIntersect(otherR) + survivingRanges += r + } + }) + }) + ranges = survivingRanges + } + + override def toString:String = { + val strBuilder = new StringBuilder + strBuilder.append("(points:(") + var isFirst = true + points.foreach( p => { + if (isFirst) isFirst = false + else strBuilder.append(",") + strBuilder.append(Bytes.toString(p)) + }) + strBuilder.append("),ranges:") + isFirst = true + ranges.foreach( r => { + if (isFirst) isFirst = false + else strBuilder.append(",") + strBuilder.append(r) + }) + strBuilder.append("))") + strBuilder.toString() + } +} + +class ColumnFilterCollection { + 
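+  // Tracks one ColumnFilter per column name. orAppend widens the filter for a column
+  // (union of its points and ranges); andAppend merges in another collection,
+  // intersecting the filters of columns that appear in both.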
val columnFilterMap = new mutable.HashMap[String, ColumnFilter] + + def orAppend(column:String, columnFilter:ColumnFilter): Unit = { + val existingFilter = columnFilterMap.get(column) + if (existingFilter.isEmpty) { + columnFilterMap.+=((column, columnFilter)) + } else { + existingFilter.get.orAppend(columnFilter) + } + } + + def orAppend(columnFilterColumn:ColumnFilterCollection): Unit = { + columnFilterColumn.columnFilterMap.foreach( e => { + orAppend(e._1, e._2) + }) + } + + def andAppend(columnFilterColumn:ColumnFilterCollection): Unit = { + columnFilterColumn.columnFilterMap.foreach( e => { + val existingColumnFilter = columnFilterMap.get(e._1) + if (existingColumnFilter.isEmpty) { + columnFilterMap += e + } else { + existingColumnFilter.get.andAppend(e._2) + } + }) + } + + def generateFamilyQualifiterFilterMap(schemaDefinitionMap: + java.util.HashMap[String, SchemaQualifierDefinition]): + util.HashMap[ColumnFamilyQualifierWrapper, ColumnFilter] = { + val familyQualifierFilterMap = new util.HashMap[ColumnFamilyQualifierWrapper, ColumnFilter]() + columnFilterMap.foreach( e => { + val definition = schemaDefinitionMap.get(e._1) + //Don't add rowKeyFilter + if (definition.columnFamilyBytes.size > 0) { + familyQualifierFilterMap.put( + new ColumnFamilyQualifierWrapper(definition.columnFamilyBytes, definition.qualifierBytes), e._2) + } + + }) + familyQualifierFilterMap + } + + override def toString:String = { + val strBuilder = new StringBuilder + columnFilterMap.foreach( e => strBuilder.append(e)) + strBuilder.toString() + } +} + +class ColumnFamilyQualifierWrapper(val columnFamily:Array[Byte], val qualifier:Array[Byte]) + extends Serializable{ + + override def equals(other:Any): Boolean = { + if (other.isInstanceOf[ColumnFamilyQualifierWrapper]) { + val otherWrapper = other.asInstanceOf[ColumnFamilyQualifierWrapper] + Bytes.compareTo(columnFamily, otherWrapper.columnFamily) == 0 && + Bytes.compareTo(qualifier, otherWrapper.qualifier) == 0 + } else { + false + } + } + + override def hashCode():Int = { + Bytes.hashCode(columnFamily) + Bytes.hashCode(qualifier) + } +} + +object Utils { + def getValue(columnName: String, + schemaMappingDefinition: java.util.HashMap[String, SchemaQualifierDefinition], + r: Result): Any = { + + val columnDef = schemaMappingDefinition.get(columnName) + + if (columnDef == null) throw new Throwable("Unknown column:" + columnName) + + + if (columnDef.columnFamilyBytes.isEmpty) { + val roKey = r.getRow + + columnDef.columnSparkSqlType match { + case IntegerType => Bytes.toInt(roKey) + case LongType => Bytes.toLong(roKey) + case FloatType => Bytes.toFloat(roKey) + case DoubleType => Bytes.toDouble(roKey) + case StringType => Bytes.toString(roKey) + case TimestampType => new Timestamp(Bytes.toLong(roKey)) + case _ => Bytes.toString(roKey) + } + } else { + val cellByteValue = + r.getColumnLatestCell(columnDef.columnFamilyBytes, columnDef.qualifierBytes) + if (cellByteValue == null) null + else columnDef.columnSparkSqlType match { + case IntegerType => Bytes.toInt(cellByteValue.getValueArray, + cellByteValue.getValueOffset, cellByteValue.getValueLength) + case LongType => Bytes.toLong(cellByteValue.getValueArray, + cellByteValue.getValueOffset, cellByteValue.getValueLength) + case FloatType => Bytes.toFloat(cellByteValue.getValueArray, + cellByteValue.getValueOffset) + case DoubleType => Bytes.toDouble(cellByteValue.getValueArray, + cellByteValue.getValueOffset) + case StringType => Bytes.toString(cellByteValue.getValueArray, + cellByteValue.getValueOffset, 
cellByteValue.getValueLength) + case TimestampType => new Timestamp(Bytes.toLong(cellByteValue.getValueArray, + cellByteValue.getValueOffset, cellByteValue.getValueLength)) + case _ => Bytes.toString(cellByteValue.getValueArray, + cellByteValue.getValueOffset, cellByteValue.getValueLength) + } + } + } + def getByteValue(columnName: String, + schemaMappingDefinition: mutable.Map[String, SchemaQualifierDefinition], + value: String): Array[Byte] = { + + val columnDef = schemaMappingDefinition.get(columnName) + + if (columnDef == null) throw new Throwable("Unknown column:" + columnName) + + if (columnDef.isEmpty) { throw new Throwable("Unknown column:" + columnName)} + else { + columnDef.get.columnSparkSqlType match { + case IntegerType => Bytes.toBytes(value.toInt) + case LongType => Bytes.toBytes(value.toLong) + case FloatType => Bytes.toBytes(value.toFloat) + case DoubleType => Bytes.toBytes(value.toDouble) + case StringType => Bytes.toBytes(value) + case TimestampType => Bytes.toBytes(value.toLong) + case _ => Bytes.toBytes(value) + } + } + } +} \ No newline at end of file diff --git a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala index f060fea..ab4dbf5 100644 --- a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala +++ b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala @@ -64,6 +64,8 @@ class HBaseContext(@transient sc: SparkContext, val broadcastedConf = sc.broadcast(new SerializableWritable(config)) val credentialsConf = sc.broadcast(new SerializableWritable(job.getCredentials)) + LatestHBaseContextCache.latest = this + if (tmpHdfsConfgFile != null && config != null) { val fs = FileSystem.newInstance(config) val tmpPath = new Path(tmpHdfsConfgFile) @@ -568,3 +570,7 @@ class HBaseContext(@transient sc: SparkContext, private[spark] def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]] } + +object LatestHBaseContextCache { + var latest:HBaseContext = null +} diff --git a/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala b/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala new file mode 100644 index 0000000..6b99019 --- /dev/null +++ b/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.hbase.spark + +import org.apache.hadoop.hbase.client.{Result, Get, Put, ConnectionFactory} +import org.apache.hadoop.hbase.util.Bytes +import org.apache.hadoop.hbase.{TableName, HBaseTestingUtility} +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkContext, Logging} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} + +class DefaultSourceSuite extends FunSuite with +BeforeAndAfterEach with BeforeAndAfterAll with Logging { + @transient var sc: SparkContext = null + var TEST_UTIL: HBaseTestingUtility = new HBaseTestingUtility + + val tableName = "t1" + val columnFamily = "c" + + override def beforeAll() { + + TEST_UTIL.startMiniCluster + + logInfo(" - minicluster started") + try + TEST_UTIL.deleteTable(TableName.valueOf(tableName)) + catch { + case e: Exception => logInfo(" - no table " + tableName + " found") + + } + logInfo(" - creating table " + tableName) + TEST_UTIL.createTable(TableName.valueOf(tableName), Bytes.toBytes(columnFamily)) + logInfo(" - created table") + + sc = new SparkContext("local", "test") + } + + override def afterAll() { + TEST_UTIL.deleteTable(TableName.valueOf(tableName)) + logInfo("shuting down minicluster") + TEST_UTIL.shutdownMiniCluster() + + sc.stop() + } + + test("dataframe.select test") { + val config = TEST_UTIL.getConfiguration + val connection = ConnectionFactory.createConnection(config) + val table = connection.getTable(TableName.valueOf("t1")) + + try { + var put = new Put(Bytes.toBytes("get1")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("1")) + table.put(put) + put = new Put(Bytes.toBytes("get2")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("4")) + table.put(put) + put = new Put(Bytes.toBytes("get3")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("8")) + table.put(put) + put = new Put(Bytes.toBytes("get4")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo4")) + put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("b"), Bytes.toBytes("10")) + table.put(put) + } finally { + table.close() + connection.close() + } + + val hbaseContext = new HBaseContext(sc, config) + val sqlContext = new SQLContext(sc) + + val df = sqlContext.load("org.apache.hadoop.hbase.spark", + Map("hbase.columns.mapping" -> "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b,", + "hbase.table" -> "t1")) + + println("simple select") + //df.select("KEY_FIELD").foreach(r => println(" - " + r)) + + println("tempTable or test") + df.registerTempTable("hbaseTmp") + + sqlContext.sql("SELECT KEY_FIELD FROM hbaseTmp " + + "WHERE " + + "(KEY_FIELD = 'get1' and B_FIELD < '3') or " + + "(KEY_FIELD <= 'get3' and B_FIELD = '8')").foreach(r => println(" - " + r)) + + println("------------------------") + + sqlContext.sql("SELECT KEY_FIELD FROM hbaseTmp " + + "WHERE " + + "(KEY_FIELD = 'get1' and B_FIELD < '3') or " + + "(KEY_FIELD < 'get3' and B_FIELD = '8')").foreach(r => println(" - " + r)) + + println("------------------------") + + sqlContext.sql("SELECT KEY_FIELD FROM hbaseTmp " + + "WHERE " + + "(KEY_FIELD > 'get1')").foreach(r => println(" - " + r)) + println("------------------------") + + 
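+    // The remaining queries mix row-key (KEY_FIELD) and column (A_FIELD, B_FIELD)
+    // predicates, exercising both the row-key point/range handling and the pushed-down
+    // column filters in buildScan.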
sqlContext.sql("SELECT KEY_FIELD FROM hbaseTmp " + + "WHERE " + + "(KEY_FIELD >= 'get1')").foreach(r => println(" - " + r)) + + println("------------------------") + + sqlContext.sql("SELECT KEY_FIELD FROM hbaseTmp " + + "WHERE " + + "(KEY_FIELD = 'get2' or KEY_FIELD = 'get1') and " + + "(B_FIELD < '3' or B_FIELD = '4')").foreach(r => println(" - " + r)) + + println("------------------------") + + sqlContext.sql("SELECT KEY_FIELD FROM hbaseTmp " + + "WHERE " + + "(KEY_FIELD > 'get1' and KEY_FIELD < 'get3') or " + + "(KEY_FIELD <= 'get4' and B_FIELD = '8')").foreach(r => println(" - " + r)) + + + sqlContext.sql("SELECT KEY_FIELD FROM hbaseTmp " + + "WHERE " + + "(A_FIELD = 'foo1' and B_FIELD < '3') or " + + "(A_FIELD < 'foo3' and B_FIELD = '8')").foreach(r => println(" - " + r)) + + println("------------------------") + + sqlContext.sql("SELECT KEY_FIELD FROM hbaseTmp " + + "WHERE " + + "(A_FIELD = 'foo2' or A_FIELD = 'foo1') and " + + "(B_FIELD < '3' or B_FIELD = '4')").foreach(r => println(" - " + r)) + + println("------------------------") + + sqlContext.sql("SELECT KEY_FIELD FROM hbaseTmp " + + "WHERE " + + "B_FIELD < '3' or " + + " B_FIELD <= '8'").foreach(r => println(" - " + r)) + } +}
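For context, here is a minimal usage sketch of the data source this patch adds, written against the Spark 1.x `SQLContext.load` API that the test above already uses. The option keys (`hbase.table`, `hbase.columns.mapping`) come from `DefaultSource`; the master, table name, and column mapping are illustrative only.

```scala
import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.hadoop.hbase.spark.HBaseContext
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

object HBaseSparkSqlExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[2]", "hbase-spark-sql-example")

    // Creating an HBaseContext registers it in LatestHBaseContextCache, which the
    // relation picks up because hbase.use.hbase.context defaults to "true".
    new HBaseContext(sc, HBaseConfiguration.create())

    val sqlContext = new SQLContext(sc)
    val df = sqlContext.load("org.apache.hadoop.hbase.spark",
      Map("hbase.table" -> "t1",
        "hbase.columns.mapping" ->
          "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b"))

    df.registerTempTable("hbaseTmp")

    // Row-key predicates become Gets or Scan ranges; column predicates are pushed down.
    sqlContext.sql("SELECT KEY_FIELD, A_FIELD FROM hbaseTmp WHERE KEY_FIELD >= 'get2'")
      .collect()
      .foreach(println)

    sc.stop()
  }
}
```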
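The subtlest part of the push-down is how ScanRange widens under OR and narrows under AND. Below is a small, self-contained sketch of that behaviour using only the ScanRange methods defined in this patch; the row-key values are arbitrary, and the outputs in the comments are traced from the merge rules above.

```scala
import org.apache.hadoop.hbase.spark.ScanRange
import org.apache.hadoop.hbase.util.Bytes

object ScanRangeMergeExample {
  def main(args: Array[String]): Unit = {
    // [get2, get4): lower bound inclusive, upper bound exclusive.
    val range = new ScanRange(Bytes.toBytes("get4"), false, Bytes.toBytes("get2"), true)
    // (get1, get3]: lower bound exclusive, upper bound inclusive.
    val other = new ScanRange(Bytes.toBytes("get3"), true, Bytes.toBytes("get1"), false)

    // OR-ed predicates widen the range to cover both operands.
    range.mergeUnion(other)
    println(range) // ScanRange:(get4,false,get1,true)

    // AND-ed predicates narrow the range back to the overlap.
    range.mergeIntersect(new ScanRange(Bytes.toBytes("get3"), true, Bytes.toBytes("get1"), false))
    println(range) // ScanRange:(get3,false,get1,false)
  }
}
```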