diff --git a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
index 469069d..8dff041 100644
--- a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
+++ b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
@@ -21,17 +21,17 @@ import java.util
 import java.util.concurrent.ConcurrentLinkedQueue
 
 import org.apache.hadoop.hbase.client._
-import org.apache.hadoop.hbase.spark.datasources.HBaseSparkConf
-import org.apache.hadoop.hbase.spark.datasources.HBaseTableScanRDD
-import org.apache.hadoop.hbase.spark.datasources.SerializableConfiguration
+import org.apache.hadoop.hbase.io.ImmutableBytesWritable
+import org.apache.hadoop.hbase.mapred.TableOutputFormat
+import org.apache.hadoop.hbase.spark.datasources.{Utils, HBaseSparkConf, HBaseTableScanRDD, SerializableConfiguration}
 import org.apache.hadoop.hbase.types._
 import org.apache.hadoop.hbase.util.{Bytes, PositionedByteRange, SimplePositionedMutableByteRange}
-import org.apache.hadoop.hbase.{HBaseConfiguration, TableName}
+import org.apache.hadoop.hbase.{HColumnDescriptor, HTableDescriptor, HBaseConfiguration, TableName}
+import org.apache.hadoop.mapred.JobConf
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.datasources.hbase.{Field, HBaseTableCatalog}
-import org.apache.spark.sql.types.{DataType => SparkDataType}
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, SaveMode, Row, SQLContext}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 
@@ -48,10 +48,11 @@ import scala.collection.mutable
  * - Type conversions of basic SQL types.  All conversions will be
  *   Through the HBase Bytes object commands.
  */
-class DefaultSource extends RelationProvider with Logging {
+class DefaultSource extends RelationProvider with CreatableRelationProvider with Logging {
   /**
    * Is given input from SparkSQL to construct a BaseRelation
-   * @param sqlContext SparkSQL context
+   *
+   * @param sqlContext SparkSQL context
    * @param parameters Parameters given to us from SparkSQL
    * @return A BaseRelation Object
    */
@@ -60,18 +61,31 @@ class DefaultSource extends RelationProvider with Logging {
   BaseRelation = {
     new HBaseRelation(parameters, None)(sqlContext)
   }
+
+
+  override def createRelation(
+      sqlContext: SQLContext,
+      mode: SaveMode,
+      parameters: Map[String, String],
+      data: DataFrame): BaseRelation = {
+    val relation = HBaseRelation(parameters, Some(data.schema))(sqlContext)
+    relation.createTable()
+    relation.insert(data, false)
+    relation
+  }
 }
 
 /**
  * Implementation of Spark BaseRelation that will build up our scan logic
  * , do the scan pruning, filter push down, and value conversions
- * @param sqlContext SparkSQL context
+ *
+ * @param sqlContext SparkSQL context
  */
 case class HBaseRelation (
     @transient parameters: Map[String, String],
     userSpecifiedSchema: Option[StructType]
   )(@transient val sqlContext: SQLContext)
-  extends BaseRelation with PrunedFilteredScan with Logging {
+  extends BaseRelation with PrunedFilteredScan with InsertableRelation with Logging {
   val catalog = HBaseTableCatalog(parameters)
   def tableName = catalog.name
   val configResources = parameters.getOrElse(HBaseSparkConf.HBASE_CONFIG_RESOURCES_LOCATIONS, "")
@@ -116,6 +130,84 @@ case class HBaseRelation (
    */
   override val schema: StructType = userSpecifiedSchema.getOrElse(catalog.toDataType)
 
+
+
+  def createTable() {
+    if (catalog.numReg > 3) {
+      val tName = TableName.valueOf(catalog.name)
+      val cfs = catalog.getColumnFamilies
+      val connection = ConnectionFactory.createConnection(hbaseConf)
+      // Initialize hBase table if necessary
+      val admin = connection.getAdmin()
+      if (!admin.isTableAvailable(tName)) {
+        val tableDesc = new HTableDescriptor(tName)
+        cfs.foreach { x =>
+          val cf = new HColumnDescriptor(x.getBytes())
+          logDebug(s"add family $x to ${catalog.name}")
+          tableDesc.addFamily(cf)
+        }
+        val startKey = Bytes.toBytes("aaaaaaa");
+        val endKey = Bytes.toBytes("zzzzzzz");
+        val splitKeys = Bytes.split(startKey, endKey, catalog.numReg - 3);
+        admin.createTable(tableDesc, splitKeys)
+        val r = connection.getRegionLocator(TableName.valueOf(catalog.name)).getAllRegionLocations
+        while(r == null || r.size() == 0) {
+          logDebug(s"region not allocated")
+          Thread.sleep(1000)
+        }
+        logDebug(s"region allocated $r")
+
+      }
+      admin.close()
+      connection.close()
+    }
+  }
+
+  /**
+   *
+   * @param data
+   * @param overwrite
+   */
+  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
+    val jobConfig: JobConf = new JobConf(hbaseConf, this.getClass)
+    jobConfig.setOutputFormat(classOf[TableOutputFormat])
+    jobConfig.set(TableOutputFormat.OUTPUT_TABLE, catalog.name)
+    var count = 0
+    val rkFields = catalog.getRowKey
+    val rkIdxedFields = rkFields.map{ case x =>
+      (schema.fieldIndex(x.colName), x)
+    }
+    val colsIdxedFields = schema
+      .fieldNames
+      .partition( x => rkFields.map(_.colName).contains(x))
+      ._2.map(x => (schema.fieldIndex(x), catalog.getField(x)))
+    val rdd = data.rdd //df.queryExecution.toRdd
+    def convertToPut(row: Row) = {
+      // construct bytes for row key
+      val rowBytes = rkIdxedFields.map { case (x, y) =>
+        Utils.toBytes(row(x), y)
+      }
+      val rLen = rowBytes.foldLeft(0) { case (x, y) =>
+        x + y.length
+      }
+      val rBytes = new Array[Byte](rLen)
+      var offset = 0
+      rowBytes.foreach { x =>
+        System.arraycopy(x, 0, rBytes, offset, x.length)
+        offset += x.length
+      }
+      val put = new Put(rBytes)
+
+      colsIdxedFields.foreach { case (x, y) =>
+        val b = Utils.toBytes(row(x), y)
+        put.addColumn(Bytes.toBytes(y.cf), Bytes.toBytes(y.col), b)
+      }
+      count += 1
+      (new ImmutableBytesWritable, put)
+    }
+    rdd.map(convertToPut(_)).saveAsHadoopDataset(jobConfig)
+  }
+
   /**
    * Here we are building the functionality to populate the resulting RDD[Row]
    * Here is where we will do the following:
@@ -356,7 +448,8 @@ class ScanRange(var upperBound:Array[Byte], var isUpperBoundEqualTo:Boolean,
 
   /**
    * Function to merge another scan object through a AND operation
-   * @param other Other scan object
+   *
+   * @param other Other scan object
    */
   def mergeIntersect(other:ScanRange): Unit = {
     val upperBoundCompare = compareRange(upperBound, other.upperBound)
@@ -376,7 +469,8 @@ class ScanRange(var upperBound:Array[Byte], var isUpperBoundEqualTo:Boolean,
 
   /**
    * Function to merge another scan object through a OR operation
-   * @param other Other scan object
+   *
+   * @param other Other scan object
    */
   def mergeUnion(other:ScanRange): Unit = {
diff --git a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Utils.scala b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Utils.scala
new file mode 100644
index 0000000..a329599
--- /dev/null
+++ b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Utils.scala
@@ -0,0 +1,44 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hbase.spark.datasources
+
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.sql.datasources.hbase.Field
+import org.apache.spark.unsafe.types.UTF8String
+
+object Utils {
+  // convert input to data type
+  def toBytes(input: Any, field: Field): Array[Byte] = {
+    input match {
+      case data: Boolean => Bytes.toBytes(data)
+      case data: Byte => Array(data)
+      case data: Array[Byte] => data
+      case data: Double => Bytes.toBytes(data)
+      case data: Float => Bytes.toBytes(data)
+      case data: Int => Bytes.toBytes(data)
+      case data: Long => Bytes.toBytes(data)
+      case data: Short => Bytes.toBytes(data)
+      case data: UTF8String => data.getBytes
+      case data: String => Bytes.toBytes(data)
+      //Bytes.toBytes(input.asInstanceOf[String])//input.asInstanceOf[UTF8String].getBytes
+      case _ => throw new Exception(s"unsupported data type ${field.dt}") //TODO
+    }
+  }
+}
diff --git a/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala b/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala
index 2987ec6..a2aa3c6 100644
--- a/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala
+++ b/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala
@@ -26,6 +26,26 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.{SparkConf, SparkContext, Logging}
 import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
 
+case class HBaseRecord(
+    col0: String,
+    col1: String,
+    col2: Double,
+    col3: Float,
+    col4: Int,
+    col5: Long)
+
+object HBaseRecord {
+  def apply(i: Int, t: String): HBaseRecord = {
+    val s = s"""row${"%03d".format(i)}"""
+    HBaseRecord(s,
+      s,
+      i.toDouble,
+      i.toFloat,
+      i,
+      i.toLong)
+  }
+}
+
 class DefaultSourceSuite extends FunSuite with
 BeforeAndAfterEach with BeforeAndAfterAll with Logging {
   @transient var sc: SparkContext = null
@@ -63,6 +83,7 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
     sparkConf.set(HBaseSparkConf.BLOCK_CACHE_ENABLE, "true")
     sparkConf.set(HBaseSparkConf.BATCH_NUM, "100")
     sparkConf.set(HBaseSparkConf.CACHE_SIZE, "100")
+
     sc = new SparkContext("local", "test", sparkConf)
 
     val connection = ConnectionFactory.createConnection(TEST_UTIL.getConfiguration)
@@ -759,4 +780,60 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
     assert(executionRules.dynamicLogicExpression == null)
 
   }
+
+  def writeCatalog = s"""{
+                    |"table":{"namespace":"default", "name":"table1"},
+                    |"rowkey":"key",
+                    |"columns":{
+                    |"col0":{"cf":"rowkey", "col":"key", "type":"string"},
+                    |"col1":{"cf":"cf1", "col":"col1", "type":"string"},
+                    |"col2":{"cf":"cf2", "col":"col2", "type":"double"},
+                    |"col3":{"cf":"cf3", "col":"col3", "type":"float"},
+                    |"col4":{"cf":"cf4", "col":"col4", "type":"int"},
+                    |"col5":{"cf":"cf5", "col":"col5", "type":"bigint"}}
+                    |}
+                    |}""".stripMargin
+
+  def withCatalog(cat: String): DataFrame = {
+    sqlContext
+      .read
+      .options(Map(HBaseTableCatalog.tableCatalog->cat))
+      .format("org.apache.hadoop.hbase.spark")
+      .load()
+  }
+
+  test("populate table") {
+    val sql = sqlContext
+    import sql.implicits._
+    val data = (0 to 255).map { i =>
+      HBaseRecord(i, "extra")
+    }
+    sc.parallelize(data).toDF.write.options(
+      Map(HBaseTableCatalog.tableCatalog -> writeCatalog, HBaseTableCatalog.newTable -> "5"))
+      .format("org.apache.hadoop.hbase.spark")
+      .save()
+  }
+
+  test("empty column") {
+    val df = withCatalog(writeCatalog)
+    df.registerTempTable("table0")
+    val c = sqlContext.sql("select count(1) from table0").rdd.collect()(0)(0).asInstanceOf[Long]
+    assert(c == 256)
+  }
+
+  test("full query") {
+    val df = withCatalog(writeCatalog)
+    df.show
+    assert(df.count() == 256)
+  }
+
+  test("filtered query0") {
+    val sql = sqlContext
+    import sql.implicits._
+    val df = withCatalog(writeCatalog)
+    val s = df.filter($"col0" <= "row005")
+      .select("col0", "col1")
+    s.show
+    assert(s.count() == 6)
+  }
 }
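
For reference, below is a minimal usage sketch of the write path this patch adds. It is not part of the patch: the case class, the catalog JSON contents, and the variable names are illustrative, and an existing SparkContext `sc` and SQLContext `sqlContext` are assumed. Only `HBaseTableCatalog.tableCatalog`, `HBaseTableCatalog.newTable`, and the "org.apache.hadoop.hbase.spark" format come from the patch itself.

    import org.apache.spark.sql.datasources.hbase.HBaseTableCatalog

    // Hypothetical record type; column names must match the catalog below.
    case class Rec(col0: String, col1: String)

    // Example catalog: col0 is the row key, col1 lives in column family cf1.
    val catalog = s"""{
                     |"table":{"namespace":"default", "name":"table1"},
                     |"rowkey":"key",
                     |"columns":{
                     |"col0":{"cf":"rowkey", "col":"key", "type":"string"},
                     |"col1":{"cf":"cf1", "col":"col1", "type":"string"}}
                     |}""".stripMargin

    import sqlContext.implicits._

    // Write: DefaultSource.createRelation() creates the table (when the
    // HBaseTableCatalog.newTable region count is > 3) and then inserts the
    // DataFrame rows as Puts through TableOutputFormat.
    sc.parallelize((0 to 9).map(i => Rec(s"row$i", s"value$i"))).toDF
      .write
      .options(Map(HBaseTableCatalog.tableCatalog -> catalog,
                   HBaseTableCatalog.newTable -> "5"))
      .format("org.apache.hadoop.hbase.spark")
      .save()

    // Read back through the same relation definition.
    val df = sqlContext.read
      .options(Map(HBaseTableCatalog.tableCatalog -> catalog))
      .format("org.apache.hadoop.hbase.spark")
      .load()
    df.show()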