diff --git a/hbase-spark/pom.xml b/hbase-spark/pom.xml
index e48f9e8..7417127 100644
--- a/hbase-spark/pom.xml
+++ b/hbase-spark/pom.xml
@@ -79,6 +79,12 @@
     </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${spark.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
       <artifactId>spark-streaming_${scala.binary.version}</artifactId>
       <version>${spark.version}</version>
       <scope>provided</scope>
diff --git a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
new file mode 100644
index 0000000..13e33e7
--- /dev/null
+++ b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hbase.spark
+
+import java.sql.Timestamp
+
+import org.apache.hadoop.hbase.client.{Result, Scan}
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.hadoop.hbase.{TableName, HBaseConfiguration}
+import org.apache.spark.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types._
+
+import scala.collection.mutable
+
+/**
+ * DefaultSource for the Spark SQL data source API.
+ *
+ * Supported options:
+ *  - hbase.table: name of the HBase table to map
+ *  - hbase.columns.mapping: comma separated "sparkColumn TYPE family:qualifier"
+ *    definitions (use ":key" to map the row key)
+ *  - hbase.batching.num / hbase.caching.num: scan batching and caching sizes
+ *  - hbase.config.resources: extra HBase configuration resources to load
+ *  - hbase.use.hbase.context: reuse the most recently created HBaseContext
+ */
+class DefaultSource extends RelationProvider {
+
+  val TABLE_KEY: String = "hbase.table"
+  val SCHEMA_COLUMNS_MAPPING_KEY: String = "hbase.columns.mapping"
+  val BATCHING_NUM_KEY: String = "hbase.batching.num"
+  val CACHING_NUM_KEY: String = "hbase.caching.num"
+  val HBASE_CONFIG_RESOURCES_LOCATIONS: String = "hbase.config.resources"
+  val USE_HBASE_CONTEXT: String = "hbase.use.hbase.context"
+
+  override def createRelation(sqlContext: SQLContext,
+                              parameters: Map[String, String]): BaseRelation = {
+
+    val tableName = parameters.getOrElse(TABLE_KEY, "")
+    val schemaMappingString = parameters.getOrElse(SCHEMA_COLUMNS_MAPPING_KEY, "")
+    val batchingNumStr = parameters.getOrElse(BATCHING_NUM_KEY, "1000")
+    val cachingNumStr = parameters.getOrElse(CACHING_NUM_KEY, "1000")
+    val hbaseConfigResources = parameters.getOrElse(HBASE_CONFIG_RESOURCES_LOCATIONS, "")
+    val useHBaseContext = parameters.getOrElse(USE_HBASE_CONTEXT, "true")
+
+    if (tableName.isEmpty) {
+      throw new Throwable("Invalid value for " + TABLE_KEY + " '" + tableName + "'")
+    }
+
+    val batchingNum: Int = try {
+      batchingNumStr.toInt
+    } catch {
+      case e: Exception =>
+        throw new Throwable("Invalid value for " + BATCHING_NUM_KEY +
+          " '" + batchingNumStr + "'", e)
+    }
+
+    val cachingNum: Int = try {
+      cachingNumStr.toInt
+    } catch {
+      case e: Exception =>
+        throw new Throwable("Invalid value for " + CACHING_NUM_KEY +
+          " '" + cachingNumStr + "'", e)
+    }
+
+    new HBaseRelation(tableName,
+      generateSchemaMappingMap(schemaMappingString),
+      batchingNum,
+      cachingNum,
+      hbaseConfigResources,
+      useHBaseContext.equalsIgnoreCase("true"))(sqlContext)
+  }
+
+  /**
+   * Parses the hbase.columns.mapping string into a map keyed by the Spark SQL
+   * column name. A mapping whose HBase part starts with ':' (for example
+   * ":key") refers to the row key and gets an empty column family.
+   */
+  def generateSchemaMappingMap(schemaMappingString: String):
+  mutable.Map[String, SchemaMapDefinition] = {
+    try {
+      val columnDefinitions = schemaMappingString.split(',')
+      val resultingMap = new mutable.HashMap[String, SchemaMapDefinition]()
+      columnDefinitions.foreach(cd => {
+        val parts = cd.trim.split(' ')
+        val hbaseDefinitionParts = if (parts(2).charAt(0) == ':') {
+          Array[String]("", "key")
+        } else {
+          parts(2).split(':')
+        }
+        resultingMap += ((parts(0), SchemaMapDefinition(parts(0),
+          parts(1), hbaseDefinitionParts(0), hbaseDefinitionParts(1))))
+      })
+      resultingMap
+    } catch {
+      case e: Exception =>
+        throw new Throwable("Invalid value for " + SCHEMA_COLUMNS_MAPPING_KEY +
+          " '" + schemaMappingString + "'", e)
+    }
+  }
+}
+
+/**
+ * Relation that maps an HBase table into Spark SQL rows with column pruning.
+ * Filter push-down is not implemented yet.
+ */
+class HBaseRelation (table: String,
+                     schemaMappingDefinition: mutable.Map[String, SchemaMapDefinition],
+                     batchingNum: Int,
+                     cachingNum: Int,
+                     configResources: String,
+                     useHBaseContext: Boolean) (
+  @transient val sqlContext: SQLContext)
+  extends BaseRelation with PrunedFilteredScan with Logging {
+
+  // Create a new HBaseContext or reuse the latest one.
+  @transient val hbaseContext: HBaseContext = if (useHBaseContext) {
+    LatestHBaseContextCache.latest
+  } else {
+    val config = HBaseConfiguration.create()
+    configResources.split(",").foreach(r => config.addResource(r))
+    new HBaseContext(sqlContext.sparkContext, config)
+  }
+
+  override def schema: StructType = {
+    new StructType(schemaMappingDefinition.values.map(c => {
+      logDebug("schema column:" + c.columnName +
+        " sparkSqlType:" + c.columnSparkSqlType)
+      val metadata = new MetadataBuilder().putString("name", c.columnName)
+      new StructField(c.columnName, c.columnSparkSqlType, nullable = true, metadata.build())
+    }).toArray)
+  }
+
+  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
+    filters.foreach(f => {
+      logDebug("- Start Root Filter")
+      logFilter(f)
+      logDebug("- Finished Root Filter")
+    })
+
+    val scan = new Scan()
+    scan.setBatch(batchingNum)
+    scan.setCaching(cachingNum)
+    // Copy the mapping into a serializable java.util.HashMap so it can be
+    // shipped inside the RDD closure.
+    val serializableMap = new java.util.HashMap[String, SchemaMapDefinition]
+    schemaMappingDefinition.foreach(e => serializableMap.put(e._1, e._2))
+    hbaseContext.hbaseRDD(TableName.valueOf(table), scan).map(r => {
+      val rowArray = requiredColumns.map(c =>
+        RelationStaticFunctions.getValue(c, serializableMap, r._2))
+      Row.fromSeq(rowArray)
+    })
+  }
+
+  // Filters are only logged for now; they are re-evaluated by Spark SQL
+  // after the full table scan.
+  def logFilter(f: Filter): Unit = {
+    f match {
+      case EqualTo(attr, value) => logDebug("EqualTo " + attr + " " + value)
+      case LessThan(attr, value) => logDebug("LessThan " + attr + " " + value)
+      case GreaterThan(attr, value) => logDebug("GreaterThan " + attr + " " + value)
+      case LessThanOrEqual(attr, value) => logDebug("LessThanOrEqual " + attr + " " + value)
+      case GreaterThanOrEqual(attr, value) => logDebug("GreaterThanOrEqual " + attr + " " + value)
+      case Or(left, right) =>
+        logFilter(left)
+        logDebug(" OR ")
+        logFilter(right)
+      case And(left, right) =>
+        logFilter(left)
+        logDebug(" AND ")
+        logFilter(right)
+      case _ =>
+        logDebug("Skipping filter: " + f)
+    }
+  }
+}
+
+object RelationStaticFunctions {
+
+  /**
+   * Converts the bytes of a single column (or of the row key, when the
+   * column family is empty) into the Spark SQL type declared in the mapping.
+   */
+  def getValue(columnName: String,
+               schemaMappingDefinition: java.util.HashMap[String, SchemaMapDefinition],
+               r: Result): Any = {
+
+    val columnDef = schemaMappingDefinition.get(columnName)
+
+    if (columnDef == null) throw new Throwable("Unknown column:" +
+      columnName)
+
+    if (columnDef.columnFamilyBytes.isEmpty) {
+      // An empty column family means the column maps to the row key.
+      val rowKey = r.getRow
+
+      columnDef.columnSparkSqlType match {
+        case IntegerType => Bytes.toInt(rowKey)
+        case LongType => Bytes.toLong(rowKey)
+        case FloatType => Bytes.toFloat(rowKey)
+        case DoubleType => Bytes.toDouble(rowKey)
+        case StringType => Bytes.toString(rowKey)
+        case TimestampType => new Timestamp(Bytes.toLong(rowKey))
+        case _ => Bytes.toString(rowKey)
+      }
+    } else {
+      val cellByteValue =
+        r.getColumnLatestCell(columnDef.columnFamilyBytes, columnDef.qualifierBytes)
+      if (cellByteValue == null) null
+      else columnDef.columnSparkSqlType match {
+        case IntegerType => Bytes.toInt(cellByteValue.getValueArray,
+          cellByteValue.getValueOffset, cellByteValue.getValueLength)
+        case LongType => Bytes.toLong(cellByteValue.getValueArray,
+          cellByteValue.getValueOffset, cellByteValue.getValueLength)
+        case FloatType => Bytes.toFloat(cellByteValue.getValueArray,
+          cellByteValue.getValueOffset)
+        case DoubleType => Bytes.toDouble(cellByteValue.getValueArray,
+          cellByteValue.getValueOffset)
+        case StringType => Bytes.toString(cellByteValue.getValueArray,
+          cellByteValue.getValueOffset, cellByteValue.getValueLength)
+        case TimestampType => new Timestamp(Bytes.toLong(cellByteValue.getValueArray,
+          cellByteValue.getValueOffset, cellByteValue.getValueLength))
+        case _ => Bytes.toString(cellByteValue.getValueArray,
+          cellByteValue.getValueOffset, cellByteValue.getValueLength)
+      }
+    }
+  }
+}
+
+/**
+ * Definition of a single Spark SQL column and the HBase column it maps to.
+ * An empty column family marks the row key.
+ */
+case class SchemaMapDefinition(columnName: String,
+                               colType: String,
+                               columnFamily: String,
+                               qualifier: String) extends Serializable {
+  val columnFamilyBytes = Bytes.toBytes(columnFamily)
+  val qualifierBytes = Bytes.toBytes(qualifier)
+  val columnSparkSqlType: DataType = colType match {
+    case "BOOLEAN" => BooleanType
+    case "TINYINT" => IntegerType
+    case "INT" => IntegerType
+    case "BIGINT" => LongType
+    case "FLOAT" => FloatType
+    case "DOUBLE" => DoubleType
+    case "STRING" => StringType
+    case "TIMESTAMP" => TimestampType
+    case "DECIMAL" => StringType // TODO DataTypes.createDecimalType(precision, scale)
+    case _ => throw new Throwable("Unsupported column type :" + colType)
+  }
+}
+
+//TODO add in filter logic for both rowkey and column values
+/*
+  def compileFilter(f: Filter, scanFilterMap: mutable.HashMap[String, ScanFilters]): Unit = {
+    f match {
+      case EqualTo(attr, value) => scanFilterMap.getOrElseUpdate(attr,
+        new ScanFilters).addPoint(value.toString)
+      case LessThan(attr, value) => scanFilterMap.getOrElseUpdate(attr,
+        new ScanFilters).addLessThan(value.toString, false)
+      case GreaterThan(attr, value) => scanFilterMap.getOrElseUpdate(attr,
+        new ScanFilters).addGreaterThan(value.toString, true)
+      case LessThanOrEqual(attr, value) => scanFilterMap.getOrElseUpdate(attr,
+        new ScanFilters).addLessThan(value.toString, false)
+      case GreaterThanOrEqual(attr, value) => scanFilterMap.getOrElseUpdate(attr,
+        new ScanFilters).addGreaterThan(value.toString, true)
+      case Or(left, right) =>
+        compileFilter(left, scanFilterMap)
+
+        val scanFilterMapRightSide = new mutable.HashMap[String, ScanFilters]()
+        compileFilter(right, scanFilterMapRightSide)
+
+        scanFilterMapRightSide.foreach(e => {
+          val leftSide = scanFilterMap.get(e._1)
+          if (leftSide.isEmpty) scanFilterMap.put(e._1, e._2)
+          else leftSide.get.orMerge(e._2)
+        })
+      case And(left, right) =>
+        compileFilter(left, scanFilterMap)
+        compileFilter(right, scanFilterMap)
+      case _ =>
+        logWarning("Skipping filter: " + f)
+    }
+  }
+
+  class ScanFilters() {
+    val ranges = new mutable.MutableList[ScanRange]()
+    var currentRange = new ScanRange(null, true, null, true)
+    ranges.+=(currentRange)
+    val points = new mutable.MutableList[String]()
+    var currentPoint: String = null
+
+    def addGreaterThan(newGreaterThan: String, isEqualTo: Boolean): Unit = {
+      if (currentRange.upperBound == null) {
+        currentRange.upperBound = newGreaterThan
+        currentRange.isUpperBoundEqualTo = isEqualTo
+      } else if (currentRange.upperBound.compareTo(newGreaterThan) > 0) {
+        currentRange.upperBound = newGreaterThan
+        currentRange.isUpperBoundEqualTo = isEqualTo
+      } else if (currentRange.upperBound.equals(newGreaterThan)) {
+        currentRange.isUpperBoundEqualTo = currentRange.isUpperBoundEqualTo && isEqualTo
+      }
+    }
+
+    def addLessThan(newLessThan: String, isEqualTo: Boolean): Unit = {
+      if (currentRange.lowerBound == null) {
+        currentRange.lowerBound = newLessThan
+        currentRange.isLowerBoundEqualTo = isEqualTo
+      } else if (currentRange.lowerBound.compareTo(newLessThan) < 0) {
+        currentRange.lowerBound = newLessThan
+        currentRange.isLowerBoundEqualTo = isEqualTo
+      } else if (currentRange.lowerBound.equals(newLessThan)) {
+        currentRange.isLowerBoundEqualTo = currentRange.isLowerBoundEqualTo && isEqualTo
+      }
+    }
+
+    def addPoint(newPoint: String): Unit = {
+      if (currentPoint == null) {
+        currentPoint = newPoint
+      } else {
+        currentPoint = ""
+      }
+    }
+
+    def orMerge(o: ScanFilters): Unit = {
+      println("OrMerge")
+    }
+  }
+
+  class ScanRange(var upperBound: String, var isUpperBoundEqualTo: Boolean,
+                  var lowerBound: String, var isLowerBoundEqualTo: Boolean)
+*/
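For reference, a small sketch of what generateSchemaMappingMap above is expected to return; the column names and mapping string here are made up for illustration and are not part of the patch:

    val source = new DefaultSource
    val mapping = source.generateSchemaMappingMap(
      "KEY_FIELD STRING :key, B_FIELD DOUBLE c:b")

    // The row-key column gets an empty column family and the qualifier "key":
    //   mapping("KEY_FIELD") == SchemaMapDefinition("KEY_FIELD", "STRING", "", "key")
    // B_FIELD reads column family "c", qualifier "b", converted to a Double:
    //   mapping("B_FIELD") == SchemaMapDefinition("B_FIELD", "DOUBLE", "c", "b")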
diff --git a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
index f060fea..ab4dbf5 100644
--- a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
+++ b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
@@ -64,6 +64,8 @@ class HBaseContext(@transient sc: SparkContext,
   val broadcastedConf = sc.broadcast(new SerializableWritable(config))
   val credentialsConf = sc.broadcast(new SerializableWritable(job.getCredentials))
 
+  LatestHBaseContextCache.latest = this
+
   if (tmpHdfsConfgFile != null && config != null) {
     val fs = FileSystem.newInstance(config)
     val tmpPath = new Path(tmpHdfsConfgFile)
@@ -568,3 +570,7 @@ class HBaseContext(@transient sc: SparkContext,
 
   private[spark] def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]]
 }
+
+object LatestHBaseContextCache {
+  var latest: HBaseContext = null
+}
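For orientation, a minimal sketch of the hand-off this LatestHBaseContextCache change enables: constructing an HBaseContext records it as the latest instance, which HBaseRelation then picks up when hbase.use.hbase.context is left at its default of true. The master URL and app name below are illustrative only:

    import org.apache.hadoop.hbase.HBaseConfiguration
    import org.apache.hadoop.hbase.spark.{HBaseContext, LatestHBaseContextCache}
    import org.apache.spark.SparkContext

    val sc = new SparkContext("local", "hbase-context-example")
    // The HBaseContext constructor registers this instance as the latest one.
    val hbaseContext = new HBaseContext(sc, HBaseConfiguration.create())
    assert(LatestHBaseContextCache.latest eq hbaseContext)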
diff --git a/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala b/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala
new file mode 100644
index 0000000..454a517
--- /dev/null
+++ b/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hbase.spark
+
+import org.apache.hadoop.hbase.client.{Result, Get, Put, ConnectionFactory}
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.hadoop.hbase.{TableName, HBaseTestingUtility}
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.{SparkContext, Logging}
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
+
+class DefaultSourceSuite extends FunSuite with
+BeforeAndAfterEach with BeforeAndAfterAll with Logging {
+  @transient var sc: SparkContext = null
+  var TEST_UTIL: HBaseTestingUtility = new HBaseTestingUtility
+
+  val tableName = "t1"
+  val columnFamily = "c"
+
+  override def beforeAll() {
+
+    TEST_UTIL.startMiniCluster
+
+    logInfo(" - minicluster started")
+    try {
+      TEST_UTIL.deleteTable(TableName.valueOf(tableName))
+    } catch {
+      case e: Exception => logInfo(" - no table " + tableName + " found")
+    }
+    logInfo(" - creating table " + tableName)
+    TEST_UTIL.createTable(TableName.valueOf(tableName), Bytes.toBytes(columnFamily))
+    logInfo(" - created table")
+
+    sc = new SparkContext("local", "test")
+  }
+
+  override def afterAll() {
+    TEST_UTIL.deleteTable(TableName.valueOf(tableName))
+    logInfo("shutting down minicluster")
+    TEST_UTIL.shutdownMiniCluster()
+
+    sc.stop()
+  }
+
+  test("DataFrame read of a basic HBase table through the DefaultSource") {
+    val config = TEST_UTIL.getConfiguration
+    val connection = ConnectionFactory.createConnection(config)
+    val table = connection.getTable(TableName.valueOf("t1"))
+
+    try {
+      var put = new Put(Bytes.toBytes("get1"))
+      put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo1"))
+      table.put(put)
+      put = new Put(Bytes.toBytes("get2"))
+      put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo2"))
+      table.put(put)
+      put = new Put(Bytes.toBytes("get3"))
+      put.addColumn(Bytes.toBytes(columnFamily), Bytes.toBytes("a"), Bytes.toBytes("foo3"))
+      table.put(put)
+    } finally {
+      table.close()
+      connection.close()
+    }
+
+    // Creating the HBaseContext registers it in LatestHBaseContextCache so
+    // the relation can reuse it (hbase.use.hbase.context defaults to true).
+    val hbaseContext = new HBaseContext(sc, config)
+    val sqlContext = new SQLContext(sc)
+
+    val df = sqlContext.load("org.apache.hadoop.hbase.spark",
+      Map("hbase.columns.mapping" -> "KEY_FIELD STRING :key, A_FIELD STRING c:a",
+        "hbase.table" -> "t1"))
+
+    val results = df.select("KEY_FIELD").collect()
+    results.foreach(r => logInfo(r.toString))
+    assert(results.length == 3)
+  }
+}
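Beyond the DataFrame call exercised in the test above, a hedged sketch of how the source might be driven end to end from SQL. It reuses the table name and column mapping from the test, while the master URL, app name, temp table name, and WHERE clause are illustrative. The equality predicate reaches HBaseRelation.buildScan as an EqualTo filter, which this patch only logs, so the rows are filtered by Spark SQL after the HBase scan:

    import org.apache.hadoop.hbase.HBaseConfiguration
    import org.apache.hadoop.hbase.spark.HBaseContext
    import org.apache.spark.SparkContext
    import org.apache.spark.sql.SQLContext

    val sc = new SparkContext("local", "hbase-sql-example")
    // Created up front so the relation reuses it via LatestHBaseContextCache.
    val hbaseContext = new HBaseContext(sc, HBaseConfiguration.create())
    val sqlContext = new SQLContext(sc)

    // Same table and mapping as the test above.
    val df = sqlContext.load("org.apache.hadoop.hbase.spark",
      Map("hbase.table" -> "t1",
        "hbase.columns.mapping" -> "KEY_FIELD STRING :key, A_FIELD STRING c:a"))

    df.registerTempTable("hbaseTmp")

    // The WHERE clause is not pushed down to HBase yet; Spark SQL applies it
    // to the rows returned by the full scan.
    sqlContext.sql("SELECT KEY_FIELD, A_FIELD FROM hbaseTmp WHERE A_FIELD = 'foo2'")
      .collect()
      .foreach(row => println(row))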