diff --git hbase-spark/src/main/java/org/apache/hadoop/hbase/spark/SparkSQLPushDownFilter.java hbase-spark/src/main/java/org/apache/hadoop/hbase/spark/SparkSQLPushDownFilter.java index c3fd25c..fe23155 100644 --- hbase-spark/src/main/java/org/apache/hadoop/hbase/spark/SparkSQLPushDownFilter.java +++ hbase-spark/src/main/java/org/apache/hadoop/hbase/spark/SparkSQLPushDownFilter.java @@ -27,8 +27,10 @@ import org.apache.hadoop.hbase.filter.FilterBase; import org.apache.hadoop.hbase.spark.protobuf.generated.FilterProtos; import org.apache.hadoop.hbase.util.ByteStringer; import org.apache.hadoop.hbase.util.Bytes; +import org.apache.spark.sql.datasources.hbase.Field; import scala.collection.mutable.MutableList; + import java.io.IOException; import java.util.HashMap; import java.util.List; @@ -66,7 +68,7 @@ public class SparkSQLPushDownFilter extends FilterBase{ public SparkSQLPushDownFilter(DynamicLogicExpression dynamicLogicExpression, byte[][] valueFromQueryArray, - MutableList columnDefinitions) { + MutableList fields) { this.dynamicLogicExpression = dynamicLogicExpression; this.valueFromQueryArray = valueFromQueryArray; @@ -74,12 +76,11 @@ public class SparkSQLPushDownFilter extends FilterBase{ this.currentCellToColumnIndexMap = new HashMap<>(); - for (int i = 0; i < columnDefinitions.size(); i++) { - SchemaQualifierDefinition definition = columnDefinitions.get(i).get(); + for (int i = 0; i < fields.size(); i++) { + Field field = fields.apply(i); - ByteArrayComparable familyByteComparable = - new ByteArrayComparable(definition.columnFamilyBytes(), - 0, definition.columnFamilyBytes().length); + byte[] cfBytes = field.cfBytes(); + ByteArrayComparable familyByteComparable = new ByteArrayComparable(cfBytes, 0, cfBytes.length); HashMap qualifierIndexMap = currentCellToColumnIndexMap.get(familyByteComparable); @@ -88,11 +89,10 @@ public class SparkSQLPushDownFilter extends FilterBase{ qualifierIndexMap = new HashMap<>(); currentCellToColumnIndexMap.put(familyByteComparable, qualifierIndexMap); } - ByteArrayComparable qualifierByteComparable = - new ByteArrayComparable(definition.qualifierBytes(), 0, - definition.qualifierBytes().length); + byte[] qBytes = field.colBytes(); + ByteArrayComparable qualifierByteComparable = new ByteArrayComparable(qBytes, 0, qBytes.length); - qualifierIndexMap.put(qualifierByteComparable, definition.columnName()); + qualifierIndexMap.put(qualifierByteComparable, field.colName()); } } diff --git hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala index 844b5b5..469069d 100644 --- hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala +++ hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala @@ -29,7 +29,8 @@ import org.apache.hadoop.hbase.util.{Bytes, PositionedByteRange, SimplePositione import org.apache.hadoop.hbase.{HBaseConfiguration, TableName} import org.apache.spark.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.datasources.hbase.{Field, HBaseTableCatalog} +import org.apache.spark.sql.types.{DataType => SparkDataType} import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -48,13 +49,6 @@ import scala.collection.mutable * Through the HBase Bytes object commands. 
*/ class DefaultSource extends RelationProvider with Logging { - - val TABLE_KEY:String = "hbase.table" - val SCHEMA_COLUMNS_MAPPING_KEY:String = "hbase.columns.mapping" - val HBASE_CONFIG_RESOURCES_LOCATIONS:String = "hbase.config.resources" - val USE_HBASE_CONTEXT:String = "hbase.use.hbase.context" - val PUSH_DOWN_COLUMN_FILTER:String = "hbase.push.down.column.filter" - /** * Is given input from SparkSQL to construct a BaseRelation * @param sqlContext SparkSQL context @@ -64,87 +58,26 @@ class DefaultSource extends RelationProvider with Logging { override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - - - val tableName = parameters.get(TABLE_KEY) - if (tableName.isEmpty) - new IllegalArgumentException("Invalid value for " + TABLE_KEY +" '" + tableName + "'") - - val schemaMappingString = parameters.getOrElse(SCHEMA_COLUMNS_MAPPING_KEY, "") - val hbaseConfigResources = parameters.getOrElse(HBASE_CONFIG_RESOURCES_LOCATIONS, "") - val useHBaseReources = parameters.getOrElse(USE_HBASE_CONTEXT, "true") - val usePushDownColumnFilter = parameters.getOrElse(PUSH_DOWN_COLUMN_FILTER, "true") - - new HBaseRelation(tableName.get, - generateSchemaMappingMap(schemaMappingString), - hbaseConfigResources, - useHBaseReources.equalsIgnoreCase("true"), - usePushDownColumnFilter.equalsIgnoreCase("true"), - parameters)(sqlContext) - } - - /** - * Reads the SCHEMA_COLUMNS_MAPPING_KEY and converts it to a map of - * SchemaQualifierDefinitions with the original sql column name as the key - * @param schemaMappingString The schema mapping string from the SparkSQL map - * @return A map of definitions keyed by the SparkSQL column name - */ - def generateSchemaMappingMap(schemaMappingString:String): - java.util.HashMap[String, SchemaQualifierDefinition] = { - try { - val columnDefinitions = schemaMappingString.split(',') - val resultingMap = new java.util.HashMap[String, SchemaQualifierDefinition]() - columnDefinitions.map(cd => { - val parts = cd.trim.split(' ') - - //Make sure we get three parts - // - if (parts.length == 3) { - val hbaseDefinitionParts = if (parts(2).charAt(0) == ':') { - Array[String]("", "key") - } else { - parts(2).split(':') - } - resultingMap.put(parts(0), new SchemaQualifierDefinition(parts(0), - parts(1), hbaseDefinitionParts(0), hbaseDefinitionParts(1))) - } else { - throw new IllegalArgumentException("Invalid value for schema mapping '" + cd + - "' should be ' :' " + - "for columns and ' :' for rowKeys") - } - }) - resultingMap - } catch { - case e:Exception => throw - new IllegalArgumentException("Invalid value for " + SCHEMA_COLUMNS_MAPPING_KEY + - " '" + schemaMappingString + "'", e ) - } + new HBaseRelation(parameters, None)(sqlContext) } } /** * Implementation of Spark BaseRelation that will build up our scan logic * , do the scan pruning, filter push down, and value conversions - * - * @param tableName HBase table that we plan to read from - * @param schemaMappingDefinition SchemaMapping information to map HBase - * Qualifiers to SparkSQL columns - * @param configResources Optional comma separated list of config resources - * to get based on their URI - * @param useHBaseContext If true this will look to see if - * HBaseContext.latest is populated to use that - * connection information * @param sqlContext SparkSQL context */ -case class HBaseRelation (val tableName:String, - val schemaMappingDefinition: - java.util.HashMap[String, SchemaQualifierDefinition], - val configResources:String, - val useHBaseContext:Boolean, - val 
usePushDownColumnFilter:Boolean, - @transient parameters: Map[String, String] ) ( - @transient val sqlContext:SQLContext) +case class HBaseRelation ( + @transient parameters: Map[String, String], + userSpecifiedSchema: Option[StructType] + )(@transient val sqlContext: SQLContext) extends BaseRelation with PrunedFilteredScan with Logging { + val catalog = HBaseTableCatalog(parameters) + def tableName = catalog.name + val configResources = parameters.getOrElse(HBaseSparkConf.HBASE_CONFIG_RESOURCES_LOCATIONS, "") + val useHBaseContext = parameters.get(HBaseSparkConf.USE_HBASE_CONTEXT).map(_.toBoolean).getOrElse(true) + val usePushDownColumnFilter = parameters.get(HBaseSparkConf.PUSH_DOWN_COLUMN_FILTER) + .map(_.toBoolean).getOrElse(true) // The user supplied per table parameter will overwrite global ones in SparkConf val blockCacheEnable = parameters.get(HBaseSparkConf.BLOCK_CACHE_ENABLE).map(_.toBoolean) @@ -181,28 +114,7 @@ case class HBaseRelation (val tableName:String, * * @return schema generated from the SCHEMA_COLUMNS_MAPPING_KEY value */ - override def schema: StructType = { - - val metadataBuilder = new MetadataBuilder() - - val structFieldArray = new Array[StructField](schemaMappingDefinition.size()) - - val schemaMappingDefinitionIt = schemaMappingDefinition.values().iterator() - var indexCounter = 0 - while (schemaMappingDefinitionIt.hasNext) { - val c = schemaMappingDefinitionIt.next() - - val metadata = metadataBuilder.putString("name", c.columnName).build() - val structField = - new StructField(c.columnName, c.columnSparkSqlType, nullable = true, metadata) - - structFieldArray(indexCounter) = structField - indexCounter += 1 - } - - val result = new StructType(structFieldArray) - result - } + override val schema: StructType = userSpecifiedSchema.getOrElse(catalog.toDataType) /** * Here we are building the functionality to populate the resulting RDD[Row] @@ -218,7 +130,6 @@ case class HBaseRelation (val tableName:String, */ override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { - val pushDownTuple = buildPushDownPredicatesResource(filters) val pushDownRowKeyFilter = pushDownTuple._1 var pushDownDynamicLogicExpression = pushDownTuple._2 @@ -236,17 +147,13 @@ case class HBaseRelation (val tableName:String, logDebug("valueArray: " + valueArray.length) val requiredQualifierDefinitionList = - new mutable.MutableList[SchemaQualifierDefinition] + new mutable.MutableList[Field] requiredColumns.foreach( c => { - val definition = schemaMappingDefinition.get(c) - requiredQualifierDefinitionList += definition + val field = catalog.getField(c) + requiredQualifierDefinitionList += field }) - //Create a local variable so that scala doesn't have to - // serialize the whole HBaseRelation Object - val serializableDefinitionMap = schemaMappingDefinition - //retain the information for unit testing checks DefaultSourceStaticUtils.populateLatestExecutionRules(pushDownRowKeyFilter, pushDownDynamicLogicExpression) @@ -258,8 +165,8 @@ case class HBaseRelation (val tableName:String, pushDownRowKeyFilter.points.foreach(p => { val get = new Get(p) requiredQualifierDefinitionList.foreach( d => { - if (d.columnFamilyBytes.length > 0) - get.addColumn(d.columnFamilyBytes, d.qualifierBytes) + if (d.isRowKey) + get.addColumn(d.cfBytes, d.colBytes) }) getList.add(get) }) @@ -276,7 +183,7 @@ case class HBaseRelation (val tableName:String, var resultRDD: RDD[Row] = { val tmp = hRdd.map{ r => Row.fromSeq(requiredColumns.map(c => - DefaultSourceStaticUtils.getValue(c, 
serializableDefinitionMap, r))) + DefaultSourceStaticUtils.getValue(catalog.getField(c), r))) } if (tmp.partitions.size > 0) { tmp @@ -291,11 +198,10 @@ case class HBaseRelation (val tableName:String, scan.setBatch(batchNum) scan.setCaching(cacheSize) requiredQualifierDefinitionList.foreach( d => - scan.addColumn(d.columnFamilyBytes, d.qualifierBytes)) + scan.addColumn(d.cfBytes, d.colBytes)) val rdd = hbaseContext.hbaseRDD(TableName.valueOf(tableName), scan).map(r => { - Row.fromSeq(requiredColumns.map(c => DefaultSourceStaticUtils.getValue(c, - serializableDefinitionMap, r._2))) + Row.fromSeq(requiredColumns.map(c => DefaultSourceStaticUtils.getValue(catalog.getField(c), r._2))) }) resultRDD=rdd } @@ -337,74 +243,73 @@ case class HBaseRelation (val tableName:String, filter match { case EqualTo(attr, value) => - val columnDefinition = schemaMappingDefinition.get(attr) - if (columnDefinition != null) { - if (columnDefinition.columnFamily.isEmpty) { + val field = catalog.getField(attr) + if (field != null) { + if (field.isRowKey) { parentRowKeyFilter.mergeIntersect(new RowKeyFilter( - DefaultSourceStaticUtils.getByteValue(attr, - schemaMappingDefinition, value.toString), null)) + DefaultSourceStaticUtils.getByteValue(field, + value.toString), null)) } val byteValue = - DefaultSourceStaticUtils.getByteValue(attr, - schemaMappingDefinition, value.toString) + DefaultSourceStaticUtils.getByteValue(field, value.toString) valueArray += byteValue } new EqualLogicExpression(attr, valueArray.length - 1, false) case LessThan(attr, value) => - val columnDefinition = schemaMappingDefinition.get(attr) - if (columnDefinition != null) { - if (columnDefinition.columnFamily.isEmpty) { + val field = catalog.getField(attr) + if (field != null) { + if (field.isRowKey) { parentRowKeyFilter.mergeIntersect(new RowKeyFilter(null, - new ScanRange(DefaultSourceStaticUtils.getByteValue(attr, - schemaMappingDefinition, value.toString), false, + new ScanRange(DefaultSourceStaticUtils.getByteValue(field, + value.toString), false, new Array[Byte](0), true))) } val byteValue = - DefaultSourceStaticUtils.getByteValue(attr, - schemaMappingDefinition, value.toString) + DefaultSourceStaticUtils.getByteValue(catalog.getField(attr), + value.toString) valueArray += byteValue } new LessThanLogicExpression(attr, valueArray.length - 1) case GreaterThan(attr, value) => - val columnDefinition = schemaMappingDefinition.get(attr) - if (columnDefinition != null) { - if (columnDefinition.columnFamily.isEmpty) { + val field = catalog.getField(attr) + if (field != null) { + if (field.isRowKey) { parentRowKeyFilter.mergeIntersect(new RowKeyFilter(null, - new ScanRange(null, true, DefaultSourceStaticUtils.getByteValue(attr, - schemaMappingDefinition, value.toString), false))) + new ScanRange(null, true, DefaultSourceStaticUtils.getByteValue(field, + value.toString), false))) } val byteValue = - DefaultSourceStaticUtils.getByteValue(attr, - schemaMappingDefinition, value.toString) + DefaultSourceStaticUtils.getByteValue(field, + value.toString) valueArray += byteValue } new GreaterThanLogicExpression(attr, valueArray.length - 1) case LessThanOrEqual(attr, value) => - val columnDefinition = schemaMappingDefinition.get(attr) - if (columnDefinition != null) { - if (columnDefinition.columnFamily.isEmpty) { + val field = catalog.getField(attr) + if (field != null) { + if (field.isRowKey) { parentRowKeyFilter.mergeIntersect(new RowKeyFilter(null, - new ScanRange(DefaultSourceStaticUtils.getByteValue(attr, - schemaMappingDefinition, value.toString), 
true, + new ScanRange(DefaultSourceStaticUtils.getByteValue(field, + value.toString), true, new Array[Byte](0), true))) } val byteValue = - DefaultSourceStaticUtils.getByteValue(attr, - schemaMappingDefinition, value.toString) + DefaultSourceStaticUtils.getByteValue(catalog.getField(attr), + value.toString) valueArray += byteValue } new LessThanOrEqualLogicExpression(attr, valueArray.length - 1) case GreaterThanOrEqual(attr, value) => - val columnDefinition = schemaMappingDefinition.get(attr) - if (columnDefinition != null) { - if (columnDefinition.columnFamily.isEmpty) { + val field = catalog.getField(attr) + if (field != null) { + if (field.isRowKey) { parentRowKeyFilter.mergeIntersect(new RowKeyFilter(null, - new ScanRange(null, true, DefaultSourceStaticUtils.getByteValue(attr, - schemaMappingDefinition, value.toString), true))) + new ScanRange(null, true, DefaultSourceStaticUtils.getByteValue(field, + value.toString), true))) } val byteValue = - DefaultSourceStaticUtils.getByteValue(attr, - schemaMappingDefinition, value.toString) + DefaultSourceStaticUtils.getByteValue(catalog.getField(attr), + value.toString) valueArray += byteValue } @@ -436,32 +341,6 @@ case class HBaseRelation (val tableName:String, } /** - * Construct to contains column data that spend SparkSQL and HBase - * - * @param columnName SparkSQL column name - * @param colType SparkSQL column type - * @param columnFamily HBase column family - * @param qualifier HBase qualifier name - */ -case class SchemaQualifierDefinition(columnName:String, - colType:String, - columnFamily:String, - qualifier:String) extends Serializable { - val columnFamilyBytes = Bytes.toBytes(columnFamily) - val qualifierBytes = Bytes.toBytes(qualifier) - val columnSparkSqlType:DataType = if (colType.equals("BOOLEAN")) BooleanType - else if (colType.equals("TINYINT")) IntegerType - else if (colType.equals("INT")) IntegerType - else if (colType.equals("BIGINT")) LongType - else if (colType.equals("FLOAT")) FloatType - else if (colType.equals("DOUBLE")) DoubleType - else if (colType.equals("STRING")) StringType - else if (colType.equals("TIMESTAMP")) TimestampType - else if (colType.equals("DECIMAL")) StringType - else throw new IllegalArgumentException("Unsupported column type :" + colType) -} - -/** * Construct to contain a single scan ranges information. Also * provide functions to merge with other scan ranges through AND * or OR operators @@ -788,35 +667,6 @@ class ColumnFilterCollection { }) } - /** - * This will collect all the filter information in a way that is optimized - * for the HBase filter commend. 
Allowing the filter to be accessed - * with columnFamily and qualifier information - * - * @param schemaDefinitionMap Schema Map that will help us map the right filters - * to the correct columns - * @return HashMap oc column filters - */ - def generateFamilyQualifiterFilterMap(schemaDefinitionMap: - java.util.HashMap[String, - SchemaQualifierDefinition]): - util.HashMap[ColumnFamilyQualifierMapKeyWrapper, ColumnFilter] = { - val familyQualifierFilterMap = - new util.HashMap[ColumnFamilyQualifierMapKeyWrapper, ColumnFilter]() - - columnFilterMap.foreach( e => { - val definition = schemaDefinitionMap.get(e._1) - //Don't add rowKeyFilter - if (definition.columnFamilyBytes.size > 0) { - familyQualifierFilterMap.put( - new ColumnFamilyQualifierMapKeyWrapper( - definition.columnFamilyBytes, 0, definition.columnFamilyBytes.length, - definition.qualifierBytes, 0, definition.qualifierBytes.length), e._2) - } - }) - familyQualifierFilterMap - } - override def toString:String = { val strBuilder = new StringBuilder columnFilterMap.foreach( e => strBuilder.append(e)) @@ -836,7 +686,7 @@ object DefaultSourceStaticUtils { val rawDouble = new RawDouble val rawString = RawString.ASCENDING - val byteRange = new ThreadLocal[PositionedByteRange]{ + val byteRange = new ThreadLocal[PositionedByteRange] { override def initialValue(): PositionedByteRange = { val range = new SimplePositionedMutableByteRange() range.setOffset(0) @@ -844,11 +694,11 @@ object DefaultSourceStaticUtils { } } - def getFreshByteRange(bytes:Array[Byte]): PositionedByteRange = { + def getFreshByteRange(bytes: Array[Byte]): PositionedByteRange = { getFreshByteRange(bytes, 0, bytes.length) } - def getFreshByteRange(bytes:Array[Byte], offset:Int = 0, length:Int): + def getFreshByteRange(bytes: Array[Byte], offset: Int = 0, length: Int): PositionedByteRange = { byteRange.get().set(bytes).setLength(length).setOffset(offset) } @@ -867,7 +717,7 @@ object DefaultSourceStaticUtils { * @param dynamicLogicExpression The dynamicLogicExpression used in the last query */ def populateLatestExecutionRules(rowKeyFilter: RowKeyFilter, - dynamicLogicExpression: DynamicLogicExpression):Unit = { + dynamicLogicExpression: DynamicLogicExpression): Unit = { lastFiveExecutionRules.add(new ExecutionRuleForUnitTesting( rowKeyFilter, dynamicLogicExpression)) while (lastFiveExecutionRules.size() > 5) { @@ -879,25 +729,16 @@ object DefaultSourceStaticUtils { * This method will convert the result content from HBase into the * SQL value type that is requested by the Spark SQL schema definition * - * @param columnName The name of the SparkSQL Column - * @param schemaMappingDefinition The schema definition map + * @param field The structure of the SparkSQL Column * @param r The result object from HBase * @return The converted object type */ - def getValue(columnName: String, - schemaMappingDefinition: - java.util.HashMap[String, SchemaQualifierDefinition], - r: Result): Any = { - - val columnDef = schemaMappingDefinition.get(columnName) - - if (columnDef == null) throw new IllegalArgumentException("Unknown column:" + columnName) - - - if (columnDef.columnFamilyBytes.isEmpty) { + def getValue(field: Field, + r: Result): Any = { + if (field.isRowKey) { val row = r.getRow - columnDef.columnSparkSqlType match { + field.dt match { case IntegerType => rawInteger.decode(getFreshByteRange(row)) case LongType => rawLong.decode(getFreshByteRange(row)) case FloatType => rawFloat.decode(getFreshByteRange(row)) @@ -908,9 +749,9 @@ object DefaultSourceStaticUtils { } } else { val 
cellByteValue = - r.getColumnLatestCell(columnDef.columnFamilyBytes, columnDef.qualifierBytes) + r.getColumnLatestCell(field.cfBytes, field.colBytes) if (cellByteValue == null) null - else columnDef.columnSparkSqlType match { + else field.dt match { case IntegerType => rawInteger.decode(getFreshByteRange(cellByteValue.getValueArray, cellByteValue.getValueOffset, cellByteValue.getValueLength)) case LongType => rawLong.decode(getFreshByteRange(cellByteValue.getValueArray, @@ -933,52 +774,41 @@ object DefaultSourceStaticUtils { * This will convert the value from SparkSQL to be stored into HBase using the * right byte Type * - * @param columnName SparkSQL column name - * @param schemaMappingDefinition Schema definition map * @param value String value from SparkSQL * @return Returns the byte array to go into HBase */ - def getByteValue(columnName: String, - schemaMappingDefinition: - java.util.HashMap[String, SchemaQualifierDefinition], - value: String): Array[Byte] = { - - val columnDef = schemaMappingDefinition.get(columnName) - - if (columnDef == null) { - throw new IllegalArgumentException("Unknown column:" + columnName) - } else { - columnDef.columnSparkSqlType match { - case IntegerType => - val result = new Array[Byte](Bytes.SIZEOF_INT) - val localDataRange = getFreshByteRange(result) - rawInteger.encode(localDataRange, value.toInt) - localDataRange.getBytes - case LongType => - val result = new Array[Byte](Bytes.SIZEOF_LONG) - val localDataRange = getFreshByteRange(result) - rawLong.encode(localDataRange, value.toLong) - localDataRange.getBytes - case FloatType => - val result = new Array[Byte](Bytes.SIZEOF_FLOAT) - val localDataRange = getFreshByteRange(result) - rawFloat.encode(localDataRange, value.toFloat) - localDataRange.getBytes - case DoubleType => - val result = new Array[Byte](Bytes.SIZEOF_DOUBLE) - val localDataRange = getFreshByteRange(result) - rawDouble.encode(localDataRange, value.toDouble) - localDataRange.getBytes - case StringType => - Bytes.toBytes(value) - case TimestampType => - val result = new Array[Byte](Bytes.SIZEOF_LONG) - val localDataRange = getFreshByteRange(result) - rawLong.encode(localDataRange, value.toLong) - localDataRange.getBytes - - case _ => Bytes.toBytes(value) - } + def getByteValue(field: Field, + value: String): Array[Byte] = { + field.dt match { + case IntegerType => + val result = new Array[Byte](Bytes.SIZEOF_INT) + val localDataRange = getFreshByteRange(result) + rawInteger.encode(localDataRange, value.toInt) + localDataRange.getBytes + case LongType => + val result = new Array[Byte](Bytes.SIZEOF_LONG) + val localDataRange = getFreshByteRange(result) + rawLong.encode(localDataRange, value.toLong) + localDataRange.getBytes + case FloatType => + val result = new Array[Byte](Bytes.SIZEOF_FLOAT) + val localDataRange = getFreshByteRange(result) + rawFloat.encode(localDataRange, value.toFloat) + localDataRange.getBytes + case DoubleType => + val result = new Array[Byte](Bytes.SIZEOF_DOUBLE) + val localDataRange = getFreshByteRange(result) + rawDouble.encode(localDataRange, value.toDouble) + localDataRange.getBytes + case StringType => + Bytes.toBytes(value) + case TimestampType => + val result = new Array[Byte](Bytes.SIZEOF_LONG) + val localDataRange = getFreshByteRange(result) + rawLong.encode(localDataRange, value.toLong) + localDataRange.getBytes + + case _ => Bytes.toBytes(value) } } } diff --git hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseSparkConf.scala 
hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseSparkConf.scala index 5e11356..ca44d42 100644 --- hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseSparkConf.scala +++ hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseSparkConf.scala @@ -31,4 +31,9 @@ object HBaseSparkConf{ val defaultBatchNum = 1000 val BULKGET_SIZE = "spark.hbase.bulkGetSize" val defaultBulkGetSize = 1000 + + val HBASE_CONFIG_RESOURCES_LOCATIONS = "hbase.config.resources" + val USE_HBASE_CONTEXT = "hbase.use.hbase.context" + val PUSH_DOWN_COLUMN_FILTER = "hbase.pushdown.column.filter" + val defaultPushDownColumnFilter = true } diff --git hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableScanRDD.scala hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableScanRDD.scala index d859957..2e05651 100644 --- hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableScanRDD.scala +++ hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableScanRDD.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.hbase.client._ import org.apache.hadoop.hbase.spark._ import org.apache.hadoop.hbase.spark.hbase._ import org.apache.hadoop.hbase.spark.datasources.HBaseResources._ +import org.apache.spark.sql.datasources.hbase.Field import org.apache.spark.{SparkEnv, TaskContext, Logging, Partition} import org.apache.spark.rdd.RDD @@ -31,7 +32,7 @@ import scala.collection.mutable class HBaseTableScanRDD(relation: HBaseRelation, val hbaseContext: HBaseContext, @transient val filter: Option[SparkSQLPushDownFilter] = None, - val columns: Seq[SchemaQualifierDefinition] = Seq.empty + val columns: Seq[Field] = Seq.empty )extends RDD[Result](relation.sqlContext.sparkContext, Nil) with Logging { private def sparkConf = SparkEnv.get.conf @transient var ranges = Seq.empty[Range] @@ -98,15 +99,15 @@ class HBaseTableScanRDD(relation: HBaseRelation, tbr: TableResource, g: Seq[Array[Byte]], filter: Option[SparkSQLPushDownFilter], - columns: Seq[SchemaQualifierDefinition], + columns: Seq[Field], hbaseContext: HBaseContext): Iterator[Result] = { g.grouped(relation.bulkGetSize).flatMap{ x => val gets = new ArrayList[Get]() x.foreach{ y => val g = new Get(y) columns.foreach { d => - if (d.columnFamilyBytes.length > 0) { - g.addColumn(d.columnFamilyBytes, d.qualifierBytes) + if (!d.isRowKey) { + g.addColumn(d.cfBytes, d.colBytes) } } filter.foreach(g.setFilter(_)) @@ -149,7 +150,7 @@ class HBaseTableScanRDD(relation: HBaseRelation, private def buildScan(range: Range, filter: Option[SparkSQLPushDownFilter], - columns: Seq[SchemaQualifierDefinition]): Scan = { + columns: Seq[Field]): Scan = { val scan = (range.lower, range.upper) match { case (Some(Bound(a, b)), Some(Bound(c, d))) => new Scan(a, c) case (None, Some(Bound(c, d))) => new Scan(Array[Byte](), c) @@ -158,8 +159,8 @@ class HBaseTableScanRDD(relation: HBaseRelation, } columns.foreach { d => - if (d.columnFamilyBytes.length > 0) { - scan.addColumn(d.columnFamilyBytes, d.qualifierBytes) + if (!d.isRowKey) { + scan.addColumn(d.cfBytes, d.colBytes) } } scan.setCacheBlocks(relation.blockCacheEnable) diff --git hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SerDes.scala hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SerDes.scala new file mode 100644 index 0000000..a19099b --- /dev/null +++ hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SerDes.scala @@ -0,0 +1,46 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hbase.spark.datasources + +import java.io.ByteArrayInputStream + +import org.apache.avro.Schema +import org.apache.avro.Schema.Type._ +import org.apache.avro.generic.GenericDatumReader +import org.apache.avro.generic.GenericDatumWriter +import org.apache.avro.generic.GenericRecord +import org.apache.avro.generic.{GenericDatumReader, GenericDatumWriter, GenericRecord} +import org.apache.avro.io._ +import org.apache.commons.io.output.ByteArrayOutputStream +import org.apache.hadoop.hbase.util.Bytes +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.sql.types._ + +trait SerDes { + def serialize(value: Any): Array[Byte] + def deserialize(bytes: Array[Byte], start: Int, end: Int): Any +} + +class DoubleSerDes extends SerDes { + override def serialize(value: Any): Array[Byte] = Bytes.toBytes(value.asInstanceOf[Double]) + override def deserialize(bytes: Array[Byte], start: Int, end: Int): Any = { + Bytes.toDouble(bytes, start) + } +} + + diff --git hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/DataTypeParserWrapper.scala hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/DataTypeParserWrapper.scala new file mode 100644 index 0000000..1e56a3d --- /dev/null +++ hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/DataTypeParserWrapper.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.datasources.hbase + +import org.apache.spark.sql.catalyst.SqlLexical +import org.apache.spark.sql.catalyst.util.DataTypeParser +import org.apache.spark.sql.types.DataType + +object DataTypeParserWrapper { + lazy val dataTypeParser = new DataTypeParser { + override val lexical = new SqlLexical + } + + def parse(dataTypeString: String): DataType = dataTypeParser.toDataType(dataTypeString) +} diff --git hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/HBaseTableCatalog.scala hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/HBaseTableCatalog.scala new file mode 100644 index 0000000..29acc90 --- /dev/null +++ hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/HBaseTableCatalog.scala @@ -0,0 +1,338 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.datasources.hbase + +import org.apache.hadoop.hbase.spark.datasources._ +import org.apache.hadoop.hbase.spark.hbase._ +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.util.DataTypeParser +import org.apache.spark.sql.types._ +import org.json4s.jackson.JsonMethods._ + +import scala.collection.mutable + +// Due the access issue defined in spark, we have to locate the file in this package. 
+// The definition of each column cell, which may be composite type +// TODO: add avro support +case class Field( + colName: String, + cf: String, + col: String, + sType: Option[String] = None, + avroSchema: Option[String] = None, + serdes: Option[SerDes]= None, + len: Int = -1) extends Logging { + override def toString = s"$colName $cf $col" + val isRowKey = cf == HBaseTableCatalog.rowKey + var start: Int = _ + + def cfBytes: Array[Byte] = { + if (isRowKey) { + Bytes.toBytes("") + } else { + Bytes.toBytes(cf) + } + } + def colBytes: Array[Byte] = { + if (isRowKey) { + Bytes.toBytes("key") + } else { + Bytes.toBytes(col) + } + } + + val dt = { + sType.map(DataTypeParser.parse(_)).get + } + + var length: Int = { + if (len == -1) { + dt match { + case BinaryType | StringType => -1 + case BooleanType => Bytes.SIZEOF_BOOLEAN + case ByteType => 1 + case DoubleType => Bytes.SIZEOF_DOUBLE + case FloatType => Bytes.SIZEOF_FLOAT + case IntegerType => Bytes.SIZEOF_INT + case LongType => Bytes.SIZEOF_LONG + case ShortType => Bytes.SIZEOF_SHORT + case _ => -1 + } + } else { + len + } + + } + + override def equals(other: Any): Boolean = other match { + case that: Field => + colName == that.colName && cf == that.cf && col == that.col + case _ => false + } +} + +// The row key definition, with each key refer to the col defined in Field, e.g., +// key1:key2:key3 +case class RowKey(k: String) { + val keys = k.split(":") + var fields: Seq[Field] = _ + var varLength = false + def length = { + if (varLength) { + -1 + } else { + fields.foldLeft(0){case (x, y) => + x + y.length + } + } + } +} +// The map between the column presented to Spark and the HBase field +case class SchemaMap(map: mutable.HashMap[String, Field]) { + def toFields = map.map { case (name, field) => + StructField(name, field.dt) + }.toSeq + + def fields = map.values + + def getField(name: String) = map(name) +} + + +// The definition of HBase and Relation relation schema +case class HBaseTableCatalog( + namespace: String, + name: String, + row: RowKey, + sMap: SchemaMap, + numReg: Int) extends Logging { + def toDataType = StructType(sMap.toFields) + def getField(name: String) = sMap.getField(name) + def getRowKey: Seq[Field] = row.fields + def getPrimaryKey= row.keys(0) + def getColumnFamilies = { + sMap.fields.map(_.cf).filter(_ != HBaseTableCatalog.rowKey) + } + + // Setup the start and length for each dimension of row key at runtime. + def dynSetupRowKey(rowKey: HBaseType) { + logDebug(s"length: ${rowKey.length}") + if(row.varLength) { + var start = 0 + row.fields.foreach { f => + logDebug(s"start: $start") + f.start = start + f.length = { + // If the length is not defined + if (f.length == -1) { + f.dt match { + case StringType => + var pos = rowKey.indexOf(HBaseTableCatalog.delimiter, start) + if (pos == -1 || pos > rowKey.length) { + // this is at the last dimension + pos = rowKey.length + } + pos - start + // We don't know the length, assume it extend to the end of the rowkey. + case _ => rowKey.length - start + } + } else { + f.length + } + } + start += f.length + } + } + } + + def initRowKey = { + val fields = sMap.fields.filter(_.cf == HBaseTableCatalog.rowKey) + row.fields = row.keys.flatMap(n => fields.find(_.col == n)) + // The length is determined at run time if it is string or binary and the length is undefined. 
+ if (row.fields.filter(_.length == -1).isEmpty) { + var start = 0 + row.fields.foreach { f => + f.start = start + start += f.length + } + } else { + row.varLength = true + } + } + initRowKey +} + +object HBaseTableCatalog { + val newTable = "newtable" + // The json string specifying hbase catalog information + val tableCatalog = "catalog" + // The row key with format key1:key2 specifying table row key + val rowKey = "rowkey" + // The key for hbase table whose value specify namespace and table name + val table = "table" + // The namespace of hbase table + val nameSpace = "namespace" + // The name of hbase table + val tableName = "name" + // The name of columns in hbase catalog + val columns = "columns" + val cf = "cf" + val col = "col" + val `type` = "type" + // the name of avro schema json string + val avro = "avro" + val delimiter: Byte = 0 + val serdes = "serdes" + val length = "length" + + /** + * User provide table schema definition + * {"tablename":"name", "rowkey":"key1:key2", + * "columns":{"col1":{"cf":"cf1", "col":"col1", "type":"type1"}, + * "col2":{"cf":"cf2", "col":"col2", "type":"type2"}}} + * Note that any col in the rowKey, there has to be one corresponding col defined in columns + */ + def apply(params: Map[String, String]): HBaseTableCatalog = { + val parameters = convert(params) + // println(jString) + val jString = parameters(tableCatalog) + val map = parse(jString).values.asInstanceOf[Map[String, _]] + val tableMeta = map.get(table).get.asInstanceOf[Map[String, _]] + val nSpace = tableMeta.get(nameSpace).getOrElse("default").asInstanceOf[String] + val tName = tableMeta.get(tableName).get.asInstanceOf[String] + val cIter = map.get(columns).get.asInstanceOf[Map[String, Map[String, String]]].toIterator + val schemaMap = mutable.HashMap.empty[String, Field] + cIter.foreach { case (name, column) => + val sd = { + column.get(serdes).asInstanceOf[Option[String]].map(n => + Class.forName(n).newInstance().asInstanceOf[SerDes] + ) + } + val len = column.get(length).map(_.toInt).getOrElse(-1) + val sAvro = column.get(avro).map(parameters(_)) + val f = Field(name, column.getOrElse(cf, rowKey), + column.get(col).get, + column.get(`type`), + sAvro, sd, len) + schemaMap.+=((name, f)) + } + val numReg = parameters.get(newTable).map(x => x.toInt).getOrElse(0) + val rKey = RowKey(map.get(rowKey).get.asInstanceOf[String]) + HBaseTableCatalog(nSpace, tName, rKey, SchemaMap(schemaMap), numReg) + } + + val TABLE_KEY: String = "hbase.table" + val SCHEMA_COLUMNS_MAPPING_KEY: String = "hbase.columns.mapping" + + /* for backward compatibility. Convert the old definition to new json based definition formated as below + val catalog = s"""{ + |"table":{"namespace":"default", "name":"htable"}, + |"rowkey":"key1:key2", + |"columns":{ + |"col1":{"cf":"rowkey", "col":"key1", "type":"string"}, + |"col2":{"cf":"rowkey", "col":"key2", "type":"double"}, + |"col3":{"cf":"cf1", "col":"col2", "type":"binary"}, + |"col4":{"cf":"cf1", "col":"col3", "type":"timestamp"}, + |"col5":{"cf":"cf1", "col":"col4", "type":"double", "serdes":"${classOf[DoubleSerDes].getName}"}, + |"col6":{"cf":"cf1", "col":"col5", "type":"$map"}, + |"col7":{"cf":"cf1", "col":"col6", "type":"$array"}, + |"col8":{"cf":"cf1", "col":"col7", "type":"$arrayMap"} + |} + |}""".stripMargin + */ + def convert(parameters: Map[String, String]): Map[String, String] = { + val tableName = parameters.get(TABLE_KEY).getOrElse(null) + // if the hbase.table is not defined, we assume it is json format already. 
+ if (tableName == null) return parameters + val schemaMappingString = parameters.getOrElse(SCHEMA_COLUMNS_MAPPING_KEY, "") + import scala.collection.JavaConverters._ + val schemaMap = generateSchemaMappingMap(schemaMappingString).asScala.map(_._2.asInstanceOf[SchemaQualifierDefinition]) + + val rowkey = schemaMap.filter { + _.columnFamily == "rowkey" + }.map(_.columnName) + val cols = schemaMap.map { x => + s""""${x.columnName}":{"cf":"${x.columnFamily}", "col":"${x.qualifier}", "type":"${x.colType}"}""".stripMargin + } + val jsonCatalog = + s"""{ + |"table":{"namespace":"default", "name":"${tableName}"}, + |"rowkey":"${rowkey.mkString(":")}", + |"columns":{ + |${cols.mkString(",")} + |} + |} + """.stripMargin + parameters ++ Map(HBaseTableCatalog.tableCatalog->jsonCatalog) + } + + /** + * Reads the SCHEMA_COLUMNS_MAPPING_KEY and converts it to a map of + * SchemaQualifierDefinitions with the original sql column name as the key + * + * @param schemaMappingString The schema mapping string from the SparkSQL map + * @return A map of definitions keyed by the SparkSQL column name + */ + def generateSchemaMappingMap(schemaMappingString:String): + java.util.HashMap[String, SchemaQualifierDefinition] = { + println(schemaMappingString) + try { + val columnDefinitions = schemaMappingString.split(',') + val resultingMap = new java.util.HashMap[String, SchemaQualifierDefinition]() + columnDefinitions.map(cd => { + val parts = cd.trim.split(' ') + + //Make sure we get three parts + // + if (parts.length == 3) { + val hbaseDefinitionParts = if (parts(2).charAt(0) == ':') { + Array[String]("rowkey", parts(0)) + } else { + parts(2).split(':') + } + resultingMap.put(parts(0), new SchemaQualifierDefinition(parts(0), + parts(1), hbaseDefinitionParts(0), hbaseDefinitionParts(1))) + } else { + throw new IllegalArgumentException("Invalid value for schema mapping '" + cd + + "' should be ' :' " + + "for columns and ' :' for rowKeys") + } + }) + resultingMap + } catch { + case e:Exception => throw + new IllegalArgumentException("Invalid value for " + SCHEMA_COLUMNS_MAPPING_KEY + + " '" + + schemaMappingString + "'", e ) + } + } + } + + /** + * Construct to contains column data that spend SparkSQL and HBase + * + * @param columnName SparkSQL column name + * @param colType SparkSQL column type + * @param columnFamily HBase column family + * @param qualifier HBase qualifier name + */ + case class SchemaQualifierDefinition(columnName:String, + colType:String, + columnFamily:String, + qualifier:String) diff --git hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala index 04dd9ba..2987ec6 100644 --- hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala +++ hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala @@ -21,6 +21,7 @@ import org.apache.hadoop.hbase.client.{Put, ConnectionFactory} import org.apache.hadoop.hbase.spark.datasources.HBaseSparkConf import org.apache.hadoop.hbase.util.Bytes import org.apache.hadoop.hbase.{TableName, HBaseTestingUtility} +import org.apache.spark.sql.datasources.hbase.HBaseTableCatalog import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.{SparkConf, SparkContext, Logging} import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} @@ -137,20 +138,37 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging { connection.close() } + def hbaseTable1Catalog = s"""{ + 
|"table":{"namespace":"default", "name":"t1"}, + |"rowkey":"key", + |"columns":{ + |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"}, + |"A_FIELD":{"cf":"c", "col":"a", "type":"string"}, + |"B_FIELD":{"cf":"c", "col":"b", "type":"string"} + |} + |}""".stripMargin + new HBaseContext(sc, TEST_UTIL.getConfiguration) sqlContext = new SQLContext(sc) df = sqlContext.load("org.apache.hadoop.hbase.spark", - Map("hbase.columns.mapping" -> - "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b,", - "hbase.table" -> "t1")) + Map(HBaseTableCatalog.tableCatalog->hbaseTable1Catalog)) df.registerTempTable("hbaseTable1") + def hbaseTable2Catalog = s"""{ + |"table":{"namespace":"default", "name":"t2"}, + |"rowkey":"key", + |"columns":{ + |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"int"}, + |"A_FIELD":{"cf":"c", "col":"a", "type":"string"}, + |"B_FIELD":{"cf":"c", "col":"b", "type":"string"} + |} + |}""".stripMargin + + df = sqlContext.load("org.apache.hadoop.hbase.spark", - Map("hbase.columns.mapping" -> - "KEY_FIELD INT :key, A_FIELD STRING c:a, B_FIELD STRING c:b,", - "hbase.table" -> "t2")) + Map(HBaseTableCatalog.tableCatalog->hbaseTable2Catalog)) df.registerTempTable("hbaseTable2") } @@ -512,13 +530,20 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging { assert(scanRange1.isUpperBoundEqualTo) } - test("Test table that doesn't exist") { + val catalog = s"""{ + |"table":{"namespace":"default", "name":"t1NotThere"}, + |"rowkey":"key", + |"columns":{ + |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"}, + |"A_FIELD":{"cf":"c", "col":"a", "type":"string"}, + |"B_FIELD":{"cf":"c", "col":"c", "type":"string"} + |} + |}""".stripMargin + intercept[Exception] { df = sqlContext.load("org.apache.hadoop.hbase.spark", - Map("hbase.columns.mapping" -> - "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b,", - "hbase.table" -> "t1NotThere")) + Map(HBaseTableCatalog.tableCatalog->catalog)) df.registerTempTable("hbaseNonExistingTmp") @@ -530,11 +555,20 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging { DefaultSourceStaticUtils.lastFiveExecutionRules.poll() } + test("Test table with column that doesn't exist") { + val catalog = s"""{ + |"table":{"namespace":"default", "name":"t1"}, + |"rowkey":"key", + |"columns":{ + |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"}, + |"A_FIELD":{"cf":"c", "col":"a", "type":"string"}, + |"B_FIELD":{"cf":"c", "col":"b", "type":"string"}, + |"C_FIELD":{"cf":"c", "col":"c", "type":"string"} + |} + |}""".stripMargin df = sqlContext.load("org.apache.hadoop.hbase.spark", - Map("hbase.columns.mapping" -> - "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b, C_FIELD STRING c:c,", - "hbase.table" -> "t1")) + Map(HBaseTableCatalog.tableCatalog->catalog)) df.registerTempTable("hbaseFactColumnTmp") @@ -549,10 +583,18 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging { } test("Test table with INT column") { + val catalog = s"""{ + |"table":{"namespace":"default", "name":"t1"}, + |"rowkey":"key", + |"columns":{ + |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"}, + |"A_FIELD":{"cf":"c", "col":"a", "type":"string"}, + |"B_FIELD":{"cf":"c", "col":"b", "type":"string"}, + |"I_FIELD":{"cf":"c", "col":"i", "type":"int"} + |} + |}""".stripMargin df = sqlContext.load("org.apache.hadoop.hbase.spark", - Map("hbase.columns.mapping" -> - "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b, I_FIELD INT c:i,", - "hbase.table" -> "t1")) + Map(HBaseTableCatalog.tableCatalog->catalog)) 
df.registerTempTable("hbaseIntTmp") @@ -571,10 +613,18 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging { } test("Test table with INT column defined at wrong type") { + val catalog = s"""{ + |"table":{"namespace":"default", "name":"t1"}, + |"rowkey":"key", + |"columns":{ + |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"}, + |"A_FIELD":{"cf":"c", "col":"a", "type":"string"}, + |"B_FIELD":{"cf":"c", "col":"b", "type":"string"}, + |"I_FIELD":{"cf":"c", "col":"i", "type":"string"} + |} + |}""".stripMargin df = sqlContext.load("org.apache.hadoop.hbase.spark", - Map("hbase.columns.mapping" -> - "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b, I_FIELD STRING c:i,", - "hbase.table" -> "t1")) + Map(HBaseTableCatalog.tableCatalog->catalog)) df.registerTempTable("hbaseIntWrongTypeTmp") @@ -594,32 +644,19 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging { assert(localResult(0).getString(2).charAt(3).toByte == 1) } - test("Test improperly formatted column mapping") { - intercept[IllegalArgumentException] { - df = sqlContext.load("org.apache.hadoop.hbase.spark", - Map("hbase.columns.mapping" -> - "KEY_FIELD,STRING,:key, A_FIELD,STRING,c:a, B_FIELD,STRING,c:b, I_FIELD,STRING,c:i,", - "hbase.table" -> "t1")) - - df.registerTempTable("hbaseBadTmp") - - val result = sqlContext.sql("SELECT KEY_FIELD, " + - "B_FIELD, I_FIELD FROM hbaseBadTmp") - - val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll() - assert(executionRules.dynamicLogicExpression == null) - - result.take(5) - } - } - - test("Test bad column type") { - intercept[IllegalArgumentException] { + val catalog = s"""{ + |"table":{"namespace":"default", "name":"t1"}, + |"rowkey":"key", + |"columns":{ + |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"FOOBAR"}, + |"A_FIELD":{"cf":"c", "col":"a", "type":"string"}, + |"I_FIELD":{"cf":"c", "col":"i", "type":"string"} + |} + |}""".stripMargin + intercept[Exception] { df = sqlContext.load("org.apache.hadoop.hbase.spark", - Map("hbase.columns.mapping" -> - "KEY_FIELD FOOBAR :key, A_FIELD STRING c:a, B_FIELD STRING c:b, I_FIELD STRING c:i,", - "hbase.table" -> "t1")) + Map(HBaseTableCatalog.tableCatalog->catalog)) df.registerTempTable("hbaseIntWrongTypeTmp") @@ -665,10 +702,18 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging { } test("Test table with sparse column") { + val catalog = s"""{ + |"table":{"namespace":"default", "name":"t1"}, + |"rowkey":"key", + |"columns":{ + |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"}, + |"A_FIELD":{"cf":"c", "col":"a", "type":"string"}, + |"B_FIELD":{"cf":"c", "col":"b", "type":"string"}, + |"Z_FIELD":{"cf":"c", "col":"z", "type":"string"} + |} + |}""".stripMargin df = sqlContext.load("org.apache.hadoop.hbase.spark", - Map("hbase.columns.mapping" -> - "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b, Z_FIELD STRING c:z,", - "hbase.table" -> "t1")) + Map(HBaseTableCatalog.tableCatalog->catalog)) df.registerTempTable("hbaseZTmp") @@ -688,11 +733,19 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging { } test("Test with column logic disabled") { + val catalog = s"""{ + |"table":{"namespace":"default", "name":"t1"}, + |"rowkey":"key", + |"columns":{ + |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"}, + |"A_FIELD":{"cf":"c", "col":"a", "type":"string"}, + |"B_FIELD":{"cf":"c", "col":"b", "type":"string"}, + |"Z_FIELD":{"cf":"c", "col":"z", "type":"string"} + |} + |}""".stripMargin df = sqlContext.load("org.apache.hadoop.hbase.spark", - Map("hbase.columns.mapping" 
-> - "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b, Z_FIELD STRING c:z,", - "hbase.table" -> "t1", - "hbase.push.down.column.filter" -> "false")) + Map(HBaseTableCatalog.tableCatalog->catalog, + HBaseSparkConf.PUSH_DOWN_COLUMN_FILTER -> "false")) df.registerTempTable("hbaseNoPushDownTmp") diff --git hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/HBaseCatalogSuite.scala hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/HBaseCatalogSuite.scala new file mode 100644 index 0000000..d83aad4 --- /dev/null +++ hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/HBaseCatalogSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hbase.spark + +import org.apache.hadoop.hbase.spark.datasources.{DoubleSerDes, SerDes} +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.Logging +import org.apache.spark.sql.datasources.hbase.{DataTypeParserWrapper, HBaseTableCatalog} +import org.apache.spark.sql.types._ +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} + +class HBaseCatalogSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll with Logging { + + val map = s"""MAP>""" + val array = s"""array>""" + val arrayMap = s"""MAp>""" + val catalog = s"""{ + |"table":{"namespace":"default", "name":"htable"}, + |"rowkey":"key1:key2", + |"columns":{ + |"col1":{"cf":"rowkey", "col":"key1", "type":"string"}, + |"col2":{"cf":"rowkey", "col":"key2", "type":"double"}, + |"col3":{"cf":"cf1", "col":"col2", "type":"binary"}, + |"col4":{"cf":"cf1", "col":"col3", "type":"timestamp"}, + |"col5":{"cf":"cf1", "col":"col4", "type":"double", "serdes":"${classOf[DoubleSerDes].getName}"}, + |"col6":{"cf":"cf1", "col":"col5", "type":"$map"}, + |"col7":{"cf":"cf1", "col":"col6", "type":"$array"}, + |"col8":{"cf":"cf1", "col":"col7", "type":"$arrayMap"} + |} + |}""".stripMargin + val parameters = Map(HBaseTableCatalog.tableCatalog->catalog) + val t = HBaseTableCatalog(parameters) + + def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = { + test(s"parse ${dataTypeString.replace("\n", "")}") { + assert(DataTypeParserWrapper.parse(dataTypeString) === expectedDataType) + } + } + test("basic") { + assert(t.getField("col1").isRowKey == true) + assert(t.getPrimaryKey == "key1") + assert(t.getField("col3").dt == BinaryType) + assert(t.getField("col4").dt == TimestampType) + assert(t.getField("col5").dt == DoubleType) + assert(t.getField("col5").serdes != None) + assert(t.getField("col4").serdes == None) + assert(t.getField("col1").isRowKey) + assert(t.getField("col2").isRowKey) + assert(!t.getField("col3").isRowKey) + assert(t.getField("col2").length == Bytes.SIZEOF_DOUBLE) + assert(t.getField("col1").length == -1) + 
assert(t.getField("col8").length == -1) + } + + checkDataType( + map, + t.getField("col6").dt + ) + + checkDataType( + array, + t.getField("col7").dt + ) + + checkDataType( + arrayMap, + t.getField("col8").dt + ) + + test("convert") { + val m = Map("hbase.columns.mapping" -> + "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b,", + "hbase.table" -> "t1") + val map = HBaseTableCatalog.convert(m) + val json = map.get(HBaseTableCatalog.tableCatalog).get + val parameters = Map(HBaseTableCatalog.tableCatalog->json) + val t = HBaseTableCatalog(parameters) + assert(t.getField("KEY_FIELD").isRowKey) + assert(DataTypeParserWrapper.parse("STRING") === t.getField("A_FIELD").dt) + assert(!t.getField("A_FIELD").isRowKey) + } + + test("compatiblity") { + val m = Map("hbase.columns.mapping" -> + "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b,", + "hbase.table" -> "t1") + val t = HBaseTableCatalog(m) + assert(t.getField("KEY_FIELD").isRowKey) + assert(DataTypeParserWrapper.parse("STRING") === t.getField("A_FIELD").dt) + assert(!t.getField("A_FIELD").isRowKey) + } +}
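For context on how the pieces introduced by this patch fit together end to end, the sketch below strings together the JSON catalog consumed by HBaseTableCatalog, the HBaseTableCatalog.tableCatalog parameter key, and the backward-compatible hbase.table / hbase.columns.mapping path handled by HBaseTableCatalog.convert. It is a minimal sketch based on the DefaultSourceSuite fixtures above, not part of the patch itself: the object name and app name are placeholders, the table "t1" with column family "c" and columns "a"/"b" is the test fixture, and the code assumes that table already exists in a reachable HBase cluster.

import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.hadoop.hbase.spark.HBaseContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.datasources.hbase.HBaseTableCatalog
import org.apache.spark.{SparkConf, SparkContext}

object HBaseCatalogUsageSketch {
  // Catalog JSON mirroring the hbaseTable1Catalog fixture in DefaultSourceSuite:
  // a string row key plus two string columns in column family "c".
  val catalog = s"""{
                   |"table":{"namespace":"default", "name":"t1"},
                   |"rowkey":"key",
                   |"columns":{
                   |"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"},
                   |"A_FIELD":{"cf":"c", "col":"a", "type":"string"},
                   |"B_FIELD":{"cf":"c", "col":"b", "type":"string"}
                   |}
                   |}""".stripMargin

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("hbase-catalog-usage"))
    // As in the test suite, an HBaseContext is built up front so the relation can
    // reuse its connection information (hbase.use.hbase.context defaults to true).
    new HBaseContext(sc, HBaseConfiguration.create())
    val sqlContext = new SQLContext(sc)

    // New style: a single JSON catalog parameter keyed by HBaseTableCatalog.tableCatalog
    // replaces the old hbase.table / hbase.columns.mapping pair.
    val df = sqlContext.load("org.apache.hadoop.hbase.spark",
      Map(HBaseTableCatalog.tableCatalog -> catalog))
    df.registerTempTable("hbaseTable1")
    sqlContext.sql("SELECT KEY_FIELD, A_FIELD, B_FIELD FROM hbaseTable1").show()

    // Old style still works: HBaseTableCatalog.convert rewrites hbase.table and
    // hbase.columns.mapping into the JSON form (see the "compatiblity" test above).
    val legacyDf = sqlContext.load("org.apache.hadoop.hbase.spark",
      Map("hbase.columns.mapping" ->
            "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b,",
          "hbase.table" -> "t1"))
    legacyDf.registerTempTable("hbaseTable1Legacy")
  }
}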