diff --git hbase-spark/pom.xml hbase-spark/pom.xml
index 7767440..776eadb 100644
--- hbase-spark/pom.xml
+++ hbase-spark/pom.xml
@@ -42,6 +42,8 @@
2.10
true
${project.basedir}/..
+ <avro.version>1.7.6</avro.version>
+
@@ -527,6 +529,11 @@
<type>test-jar</type>
<scope>test</scope>
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro</artifactId>
+ <version>${avro.version}</version>
+ </dependency>
diff --git hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
index 6a6bc1a..6852aa7 100644
--- hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
+++ hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
@@ -23,17 +23,16 @@ import java.util.concurrent.ConcurrentLinkedQueue
import org.apache.hadoop.hbase.client._
import org.apache.hadoop.hbase.io.ImmutableBytesWritable
import org.apache.hadoop.hbase.mapred.TableOutputFormat
-import org.apache.hadoop.hbase.spark.datasources.Utils
import org.apache.hadoop.hbase.spark.datasources.HBaseSparkConf
import org.apache.hadoop.hbase.spark.datasources.HBaseTableScanRDD
import org.apache.hadoop.hbase.spark.datasources.SerializableConfiguration
import org.apache.hadoop.hbase.types._
import org.apache.hadoop.hbase.util.{Bytes, PositionedByteRange, SimplePositionedMutableByteRange}
-import org.apache.hadoop.hbase.{HColumnDescriptor, HTableDescriptor, HBaseConfiguration, TableName}
+import org.apache.hadoop.hbase._
import org.apache.hadoop.mapred.JobConf
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.datasources.hbase.{Field, HBaseTableCatalog}
+import org.apache.spark.sql.datasources.hbase.{Utils, Field, HBaseTableCatalog}
import org.apache.spark.sql.{DataFrame, SaveMode, Row, SQLContext}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -217,6 +216,62 @@ case class HBaseRelation (
rdd.map(convertToPut(_)).saveAsHadoopDataset(jobConfig)
}
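+ /**
+ * Pairs each requested column name with its catalog Field and with the column's
+ * index in the requested projection.
+ */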
+ def getIndexedProjections(requiredColumns: Array[String]): Seq[(Field, Int)] = {
+ requiredColumns.map(catalog.sMap.getField(_)).zipWithIndex
+ }
+
+
+ /**
+ * Takes an HBase row key (the raw bytes) and parses all of the key fields from it.
+ * This is independent of which fields were requested from the key:
+ * because we have all of the data, it is less complex to parse everything.
+ *
+ * @param row the raw row key bytes
+ * @param keyFields all of the fields in the row key, ordered as they appear in the row key
+ */
+ def parseRowKey(row: Array[Byte], keyFields: Seq[Field]): Map[Field, Any] = {
+ keyFields.foldLeft((0, Seq[(Field, Any)]()))((state, field) => {
+ val idx = state._1
+ val parsed = state._2
+ if (field.length != -1) {
+ val value = Utils.hbaseFieldToScalaType(field, row, idx, field.length)
+ // Return the new index and appended value
+ (idx + field.length, parsed ++ Seq((field, value)))
+ } else {
+ field.dt match {
+ case StringType =>
+ val pos = row.indexOf(HBaseTableCatalog.delimiter, idx)
+ if (pos == -1 || pos > row.length) {
+ // this is the last dimension of the row key
+ val value = Utils.hbaseFieldToScalaType(field, row, idx, row.length)
+ (row.length + 1, parsed ++ Seq((field, value)))
+ } else {
+ val value = Utils.hbaseFieldToScalaType(field, row, idx, pos - idx)
+ (pos, parsed ++ Seq((field, value)))
+ }
+ // We don't know the length; assume it extends to the end of the row key.
+ case _ => (row.length + 1, parsed ++ Seq((field, Utils.hbaseFieldToScalaType(field, row, idx, row.length))))
+ }
+ }
+ })._2.toMap
+ }
+
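+ /**
+ * Builds a Spark SQL Row from an HBase Result: the row key is parsed into its key
+ * fields, every non-key field is read from the latest cell of its column, and the
+ * values are then ordered to match the requested fields.
+ */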
+ def buildRow(fields: Seq[Field], result: Result): Row = {
+ val r = result.getRow
+ val keySeq = parseRowKey(r, catalog.getRowKey)
+ val valueSeq = fields.filter(!_.isRowKey).map { x =>
+ val kv = result.getColumnLatestCell(Bytes.toBytes(x.cf), Bytes.toBytes(x.col))
+ if (kv == null || kv.getValueLength == 0) {
+ (x, null)
+ } else {
+ val v = CellUtil.cloneValue(kv)
+ (x, Utils.hbaseFieldToScalaType(x, v, 0, v.length))
+ }
+ }.toMap
+ val unionedRow = keySeq ++ valueSeq
+ // Return the row ordered by the requested order
+ Row.fromSeq(fields.map(unionedRow.get(_).getOrElse(null)))
+ }
+
/**
* Here we are building the functionality to populate the resulting RDD[Row]
* Here is where we will do the following:
@@ -281,10 +336,12 @@ case class HBaseRelation (
val hRdd = new HBaseTableScanRDD(this, hbaseContext, pushDownFilterJava, requiredQualifierDefinitionList.seq)
pushDownRowKeyFilter.points.foreach(hRdd.addPoint(_))
pushDownRowKeyFilter.ranges.foreach(hRdd.addRange(_))
+
var resultRDD: RDD[Row] = {
val tmp = hRdd.map{ r =>
- Row.fromSeq(requiredColumns.map(c =>
- DefaultSourceStaticUtils.getValue(catalog.getField(c), r)))
+ val indexedFields = getIndexedProjections(requiredColumns).map(_._1)
+ buildRow(indexedFields, r)
+
}
if (tmp.partitions.size > 0) {
tmp
@@ -302,7 +359,8 @@ case class HBaseRelation (
scan.addColumn(d.cfBytes, d.colBytes))
val rdd = hbaseContext.hbaseRDD(TableName.valueOf(tableName), scan).map(r => {
- Row.fromSeq(requiredColumns.map(c => DefaultSourceStaticUtils.getValue(catalog.getField(c), r._2)))
+ val indexedFields = getIndexedProjections(requiredColumns).map(_._1)
+ buildRow(indexedFields, r._2)
})
resultRDD=rdd
}
diff --git hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SchemaConverters.scala hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SchemaConverters.scala
new file mode 100644
index 0000000..3db7f9c
--- /dev/null
+++ hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/SchemaConverters.scala
@@ -0,0 +1,427 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hbase.spark
+
+import java.io.ByteArrayInputStream
+import java.nio.ByteBuffer
+import java.sql.Timestamp
+import java.util
+import java.util.HashMap
+
+import org.apache.avro.SchemaBuilder.BaseFieldTypeBuilder
+import org.apache.avro.SchemaBuilder.BaseTypeBuilder
+import org.apache.avro.SchemaBuilder.FieldAssembler
+import org.apache.avro.SchemaBuilder.FieldDefault
+import org.apache.avro.SchemaBuilder.RecordBuilder
+import org.apache.avro.io._
+import org.apache.commons.io.output.ByteArrayOutputStream
+import org.apache.hadoop.hbase.util.Bytes
+
+import scala.collection.JavaConversions._
+
+import org.apache.avro.{SchemaBuilder, Schema}
+import org.apache.avro.Schema.Type._
+import org.apache.avro.SchemaBuilder._
+import org.apache.avro.generic.GenericData.{Record, Fixed}
+import org.apache.avro.generic.{GenericDatumReader, GenericDatumWriter, GenericData, GenericRecord}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+
+import scala.collection.immutable.Map
+
+
+abstract class AvroException(msg: String) extends Exception(msg)
+case class SchemaConversionException(msg: String) extends AvroException(msg)
+
+/**
+ * At the top level, the converters provide three high-level interfaces:
+ * 1. toSqlType: takes an Avro schema and returns a Spark SQL schema.
+ * 2. createConverterToSQL: returns a function that converts Avro types to their
+ *    corresponding Spark SQL representations.
+ * 3. convertTypeToAvro: constructs a converter function for a given Spark SQL data type;
+ *    this is used when writing Avro records out to disk.
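+ *
+ * A minimal usage sketch (the schema string below is illustrative only):
+ * {{{
+ * val avroSchema = new Schema.Parser().parse(
+ *   """{"type":"record","name":"User","fields":[{"name":"name","type":"string"}]}""")
+ * val sparkType = SchemaConverters.toSqlType(avroSchema).dataType  // StructType with one string field
+ * val toSqlRow = SchemaConverters.createConverterToSQL(avroSchema) // GenericRecord => Row
+ * }}}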
+ */
+object SchemaConverters {
+
+ case class SchemaType(dataType: DataType, nullable: Boolean)
+
+ /**
+ * This function takes an avro schema and returns a sql schema.
+ */
+ def toSqlType(avroSchema: Schema): SchemaType = {
+ avroSchema.getType match {
+ case INT => SchemaType(IntegerType, nullable = false)
+ case STRING => SchemaType(StringType, nullable = false)
+ case BOOLEAN => SchemaType(BooleanType, nullable = false)
+ case BYTES => SchemaType(BinaryType, nullable = false)
+ case DOUBLE => SchemaType(DoubleType, nullable = false)
+ case FLOAT => SchemaType(FloatType, nullable = false)
+ case LONG => SchemaType(LongType, nullable = false)
+ case FIXED => SchemaType(BinaryType, nullable = false)
+ case ENUM => SchemaType(StringType, nullable = false)
+
+ case RECORD =>
+ val fields = avroSchema.getFields.map { f =>
+ val schemaType = toSqlType(f.schema())
+ StructField(f.name, schemaType.dataType, schemaType.nullable)
+ }
+
+ SchemaType(StructType(fields), nullable = false)
+
+ case ARRAY =>
+ val schemaType = toSqlType(avroSchema.getElementType)
+ SchemaType(
+ ArrayType(schemaType.dataType, containsNull = schemaType.nullable),
+ nullable = false)
+
+ case MAP =>
+ val schemaType = toSqlType(avroSchema.getValueType)
+ SchemaType(
+ MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable),
+ nullable = false)
+
+ case UNION =>
+ if (avroSchema.getTypes.exists(_.getType == NULL)) {
+ // In case of a union with null, eliminate it and make a recursive call
+ val remainingUnionTypes = avroSchema.getTypes.filterNot(_.getType == NULL)
+ if (remainingUnionTypes.size == 1) {
+ toSqlType(remainingUnionTypes.get(0)).copy(nullable = true)
+ } else {
+ toSqlType(Schema.createUnion(remainingUnionTypes)).copy(nullable = true)
+ }
+ } else avroSchema.getTypes.map(_.getType) match {
+ case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) =>
+ SchemaType(LongType, nullable = false)
+ case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) =>
+ SchemaType(DoubleType, nullable = false)
+ case other => throw new SchemaConversionException(
+ s"This mix of union types is not supported (see README): $other")
+ }
+
+ case other => throw new SchemaConversionException(s"Unsupported type $other")
+ }
+ }
+
+ /**
+ * This function converts sparkSQL StructType into avro schema. This method uses two other
+ * converter methods in order to do the conversion.
+ */
+ private def convertStructToAvro[T](
+ structType: StructType,
+ schemaBuilder: RecordBuilder[T],
+ recordNamespace: String): T = {
+ val fieldsAssembler: FieldAssembler[T] = schemaBuilder.fields()
+ structType.fields.foreach { field =>
+ val newField = fieldsAssembler.name(field.name).`type`()
+
+ if (field.nullable) {
+ convertFieldTypeToAvro(field.dataType, newField.nullable(), field.name, recordNamespace)
+ .noDefault
+ } else {
+ convertFieldTypeToAvro(field.dataType, newField, field.name, recordNamespace)
+ .noDefault
+ }
+ }
+ fieldsAssembler.endRecord()
+ }
+
+ /**
+ * Returns a function that is used to convert avro types to their
+ * corresponding sparkSQL representations.
+ */
+ def createConverterToSQL(schema: Schema): Any => Any = {
+ schema.getType match {
+ // Avro strings are in Utf8, so we have to call toString on them
+ case STRING | ENUM => (item: Any) => if (item == null) null else item.toString
+ case INT | BOOLEAN | DOUBLE | FLOAT | LONG => identity
+ // Byte arrays are reused by avro, so we have to make a copy of them.
+ case FIXED => (item: Any) => if (item == null) {
+ null
+ } else {
+ item.asInstanceOf[Fixed].bytes().clone()
+ }
+ case BYTES => (item: Any) => if (item == null) {
+ null
+ } else {
+ val bytes = item.asInstanceOf[ByteBuffer]
+ val javaBytes = new Array[Byte](bytes.remaining)
+ bytes.get(javaBytes)
+ javaBytes
+ }
+ case RECORD =>
+ val fieldConverters = schema.getFields.map(f => createConverterToSQL(f.schema))
+ (item: Any) => if (item == null) {
+ null
+ } else {
+ val record = item.asInstanceOf[GenericRecord]
+ val converted = new Array[Any](fieldConverters.size)
+ var idx = 0
+ while (idx < fieldConverters.size) {
+ converted(idx) = fieldConverters.apply(idx)(record.get(idx))
+ idx += 1
+ }
+ Row.fromSeq(converted.toSeq)
+ }
+ case ARRAY =>
+ val elementConverter = createConverterToSQL(schema.getElementType)
+ (item: Any) => if (item == null) {
+ null
+ } else {
+ try {
+ item.asInstanceOf[GenericData.Array[Any]].map(elementConverter)
+ } catch {
+ case e: Throwable =>
+ item.asInstanceOf[util.ArrayList[Any]].map(elementConverter)
+ }
+ }
+ case MAP =>
+ val valueConverter = createConverterToSQL(schema.getValueType)
+ (item: Any) => if (item == null) {
+ null
+ } else {
+ item.asInstanceOf[HashMap[Any, Any]].map(x => (x._1.toString, valueConverter(x._2))).toMap
+ }
+ case UNION =>
+ if (schema.getTypes.exists(_.getType == NULL)) {
+ val remainingUnionTypes = schema.getTypes.filterNot(_.getType == NULL)
+ if (remainingUnionTypes.size == 1) {
+ createConverterToSQL(remainingUnionTypes.get(0))
+ } else {
+ createConverterToSQL(Schema.createUnion(remainingUnionTypes))
+ }
+ } else schema.getTypes.map(_.getType) match {
+ case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) =>
+ (item: Any) => {
+ item match {
+ case l: Long => l
+ case i: Int => i.toLong
+ case null => null
+ }
+ }
+ case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) =>
+ (item: Any) => {
+ item match {
+ case d: Double => d
+ case f: Float => f.toDouble
+ case null => null
+ }
+ }
+ case other => throw new SchemaConversionException(
+ s"This mix of union types is not supported (see README): $other")
+ }
+ case other => throw new SchemaConversionException(s"invalid avro type: $other")
+ }
+ }
+
+ /**
+ * This function is used to convert a Spark SQL type to an Avro type. Note that this function
+ * is not used to construct the fields of an Avro record (convertFieldTypeToAvro is used for that).
+ */
+ private def convertTypeToAvro[T](
+ dataType: DataType,
+ schemaBuilder: BaseTypeBuilder[T],
+ structName: String,
+ recordNamespace: String): T = {
+ dataType match {
+ case ByteType => schemaBuilder.intType()
+ case ShortType => schemaBuilder.intType()
+ case IntegerType => schemaBuilder.intType()
+ case LongType => schemaBuilder.longType()
+ case FloatType => schemaBuilder.floatType()
+ case DoubleType => schemaBuilder.doubleType()
+ case _: DecimalType => schemaBuilder.stringType()
+ case StringType => schemaBuilder.stringType()
+ case BinaryType => schemaBuilder.bytesType()
+ case BooleanType => schemaBuilder.booleanType()
+ case TimestampType => schemaBuilder.longType()
+
+ case ArrayType(elementType, _) =>
+ val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull)
+ val elementSchema = convertTypeToAvro(elementType, builder, structName, recordNamespace)
+ schemaBuilder.array().items(elementSchema)
+
+ case MapType(StringType, valueType, _) =>
+ val builder = getSchemaBuilder(dataType.asInstanceOf[MapType].valueContainsNull)
+ val valueSchema = convertTypeToAvro(valueType, builder, structName, recordNamespace)
+ schemaBuilder.map().values(valueSchema)
+
+ case structType: StructType =>
+ convertStructToAvro(
+ structType,
+ schemaBuilder.record(structName).namespace(recordNamespace),
+ recordNamespace)
+
+ case other => throw new IllegalArgumentException(s"Unexpected type $dataType.")
+ }
+ }
+
+ /**
+ * This function is used to construct the fields of an Avro record, where the schema of each field
+ * is specified by the Avro representation of dataType. Since builders for record fields are different
+ * from those for everything else, we have to use a separate method.
+ */
+ private def convertFieldTypeToAvro[T](
+ dataType: DataType,
+ newFieldBuilder: BaseFieldTypeBuilder[T],
+ structName: String,
+ recordNamespace: String): FieldDefault[T, _] = {
+ dataType match {
+ case ByteType => newFieldBuilder.intType()
+ case ShortType => newFieldBuilder.intType()
+ case IntegerType => newFieldBuilder.intType()
+ case LongType => newFieldBuilder.longType()
+ case FloatType => newFieldBuilder.floatType()
+ case DoubleType => newFieldBuilder.doubleType()
+ case _: DecimalType => newFieldBuilder.stringType()
+ case StringType => newFieldBuilder.stringType()
+ case BinaryType => newFieldBuilder.bytesType()
+ case BooleanType => newFieldBuilder.booleanType()
+ case TimestampType => newFieldBuilder.longType()
+
+ case ArrayType(elementType, _) =>
+ val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull)
+ val elementSchema = convertTypeToAvro(elementType, builder, structName, recordNamespace)
+ newFieldBuilder.array().items(elementSchema)
+
+ case MapType(StringType, valueType, _) =>
+ val builder = getSchemaBuilder(dataType.asInstanceOf[MapType].valueContainsNull)
+ val valueSchema = convertTypeToAvro(valueType, builder, structName, recordNamespace)
+ newFieldBuilder.map().values(valueSchema)
+
+ case structType: StructType =>
+ convertStructToAvro(
+ structType,
+ newFieldBuilder.record(structName).namespace(recordNamespace),
+ recordNamespace)
+
+ case other => throw new IllegalArgumentException(s"Unexpected type $dataType.")
+ }
+ }
+
+ private def getSchemaBuilder(isNullable: Boolean): BaseTypeBuilder[Schema] = {
+ if (isNullable) {
+ SchemaBuilder.builder().nullable()
+ } else {
+ SchemaBuilder.builder()
+ }
+ }
+ /**
+ * This function constructs a converter function for a given Spark SQL data type. This is used
+ * when writing Avro records out to disk.
+ */
+ def createConverterToAvro(
+ dataType: DataType,
+ structName: String,
+ recordNamespace: String): (Any) => Any = {
+ dataType match {
+ case BinaryType => (item: Any) => item match {
+ case null => null
+ case bytes: Array[Byte] => ByteBuffer.wrap(bytes)
+ }
+ case ByteType | ShortType | IntegerType | LongType |
+ FloatType | DoubleType | StringType | BooleanType => identity
+ case _: DecimalType => (item: Any) => if (item == null) null else item.toString
+ case TimestampType => (item: Any) =>
+ if (item == null) null else item.asInstanceOf[Timestamp].getTime
+ case ArrayType(elementType, _) =>
+ val elementConverter = createConverterToAvro(elementType, structName, recordNamespace)
+ (item: Any) => {
+ if (item == null) {
+ null
+ } else {
+ val sourceArray = item.asInstanceOf[Seq[Any]]
+ val sourceArraySize = sourceArray.size
+ val targetArray = new util.ArrayList[Any](sourceArraySize)
+ var idx = 0
+ while (idx < sourceArraySize) {
+ targetArray.add(elementConverter(sourceArray(idx)))
+ idx += 1
+ }
+ targetArray
+ }
+ }
+ case MapType(StringType, valueType, _) =>
+ val valueConverter = createConverterToAvro(valueType, structName, recordNamespace)
+ (item: Any) => {
+ if (item == null) {
+ null
+ } else {
+ val javaMap = new HashMap[String, Any]()
+ item.asInstanceOf[Map[String, Any]].foreach { case (key, value) =>
+ javaMap.put(key, valueConverter(value))
+ }
+ javaMap
+ }
+ }
+ case structType: StructType =>
+ val builder = SchemaBuilder.record(structName).namespace(recordNamespace)
+ val schema: Schema = SchemaConverters.convertStructToAvro(
+ structType, builder, recordNamespace)
+ val fieldConverters = structType.fields.map(field =>
+ createConverterToAvro(field.dataType, field.name, recordNamespace))
+ (item: Any) => {
+ if (item == null) {
+ null
+ } else {
+ val record = new Record(schema)
+ val convertersIterator = fieldConverters.iterator
+ val fieldNamesIterator = dataType.asInstanceOf[StructType].fieldNames.iterator
+ val rowIterator = item.asInstanceOf[Row].toSeq.iterator
+
+ while (convertersIterator.hasNext) {
+ val converter = convertersIterator.next()
+ record.put(fieldNamesIterator.next(), converter(rowIterator.next()))
+ }
+ record
+ }
+ }
+ }
+ }
+}
+
+
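+/**
+ * Minimal Avro serialization/deserialization helpers used by the HBase data source to
+ * store and read Avro-encoded values in cells and row keys.
+ */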
+object AvroSedes {
+ // We only handle the case where the top-level type is a record or a primitive type for now
+ def serialize(input: Any, schema: Schema): Array[Byte] = {
+ schema.getType match {
+ case BOOLEAN => Bytes.toBytes(input.asInstanceOf[Boolean])
+ case BYTES | FIXED=> input.asInstanceOf[Array[Byte]]
+ case DOUBLE => Bytes.toBytes(input.asInstanceOf[Double])
+ case FLOAT => Bytes.toBytes(input.asInstanceOf[Float])
+ case INT => Bytes.toBytes(input.asInstanceOf[Int])
+ case LONG => Bytes.toBytes(input.asInstanceOf[Long])
+ case STRING => Bytes.toBytes(input.asInstanceOf[String])
+ case RECORD =>
+ val gr = input.asInstanceOf[GenericRecord]
+ val writer2 = new GenericDatumWriter[GenericRecord](schema)
+ val bao2 = new ByteArrayOutputStream()
+ val encoder2: BinaryEncoder = EncoderFactory.get().directBinaryEncoder(bao2, null)
+ writer2.write(gr, encoder2)
+ bao2.toByteArray()
+ case _ => throw new Exception(s"unsupported data type ${schema.getType}") //TODO
+ }
+ }
+
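+ /**
+ * Decodes bytes written by serialize for a record schema back into a GenericRecord,
+ * using Avro's direct binary decoding.
+ */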
+ def deserialize(input: Array[Byte], schema: Schema): GenericRecord = {
+ val reader2: DatumReader[GenericRecord] = new GenericDatumReader[GenericRecord](schema)
+ val bai2 = new ByteArrayInputStream(input)
+ val decoder2: BinaryDecoder = DecoderFactory.get().directBinaryDecoder(bai2, null)
+ val gr2: GenericRecord = reader2.read(null, decoder2)
+ gr2
+ }
+}
diff --git hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Utils.scala hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Utils.scala
deleted file mode 100644
index 090e81a..0000000
--- hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/Utils.scala
+++ /dev/null
@@ -1,44 +0,0 @@
-
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.hadoop.hbase.spark.datasources
-
-import org.apache.hadoop.hbase.util.Bytes
-import org.apache.hadoop.hbase.util.Bytes
-import org.apache.spark.sql.datasources.hbase.Field
-import org.apache.spark.unsafe.types.UTF8String
-
-object Utils {
- // convert input to data type
- def toBytes(input: Any, field: Field): Array[Byte] = {
- input match {
- case data: Boolean => Bytes.toBytes(data)
- case data: Byte => Array(data)
- case data: Array[Byte] => data
- case data: Double => Bytes.toBytes(data)
- case data: Float => Bytes.toBytes(data)
- case data: Int => Bytes.toBytes(data)
- case data: Long => Bytes.toBytes(data)
- case data: Short => Bytes.toBytes(data)
- case data: UTF8String => data.getBytes
- case data: String => Bytes.toBytes(data)
- // TODO: add more data type support
- case _ => throw new Exception(s"unsupported data type ${field.dt}")
- }
- }
-}
diff --git hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/HBaseTableCatalog.scala hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/HBaseTableCatalog.scala
index 45fa60f..831c7de 100644
--- hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/HBaseTableCatalog.scala
+++ hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/HBaseTableCatalog.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.datasources.hbase
+import org.apache.avro.Schema
+import org.apache.hadoop.hbase.spark.SchemaConverters
import org.apache.hadoop.hbase.spark.datasources._
import org.apache.hadoop.hbase.spark.hbase._
import org.apache.hadoop.hbase.util.Bytes
@@ -41,6 +43,23 @@ case class Field(
override def toString = s"$colName $cf $col"
val isRowKey = cf == HBaseTableCatalog.rowKey
var start: Int = _
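+ // Avro schema for this field, parsed from the schema string supplied through the catalog when one is defined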
+ def schema: Option[Schema] = avroSchema.map { x =>
+ logDebug(s"avro: $x")
+ val p = new Schema.Parser
+ p.parse(x)
+ }
+
+ lazy val exeSchema = schema
+
+ // converter from avro to catalyst structure
+ lazy val avroToCatalyst: Option[Any => Any] = {
+ schema.map(SchemaConverters.createConverterToSQL(_))
+ }
+
+ // converter from catalyst to avro
+ lazy val catalystToAvro: (Any) => Any = {
+ SchemaConverters.createConverterToAvro(dt, colName, "recordNamespace")
+ }
def cfBytes: Array[Byte] = {
if (isRowKey) {
@@ -58,7 +77,11 @@ case class Field(
}
val dt = {
- sType.map(DataTypeParser.parse(_)).get
+ sType.map(DataTypeParser.parse(_)).getOrElse {
+ schema.map { x =>
+ SchemaConverters.toSqlType(x).dataType
+ }.get
+ }
}
var length: Int = {
diff --git hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/Utils.scala hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/Utils.scala
new file mode 100644
index 0000000..0d13576
--- /dev/null
+++ hbase-spark/src/main/scala/org/apache/spark/sql/datasources/hbase/Utils.scala
@@ -0,0 +1,97 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.datasources.hbase
+
+import org.apache.hadoop.hbase.spark.AvroSedes
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+object Utils {
+
+
+ /**
+ * Parses the HBase field into its corresponding
+ * Scala type, which can then be put into a Spark GenericRow,
+ * which is then automatically converted by Spark.
+ */
+ def hbaseFieldToScalaType(
+ f: Field,
+ src: Array[Byte],
+ offset: Int,
+ length: Int): Any = {
+ if (f.exeSchema.isDefined) {
+ // If an Avro schema is defined, use it to deserialize the record and then
+ // convert it to the corresponding catalyst data type
+ val m = AvroSedes.deserialize(src, f.exeSchema.get)
+ val n = f.avroToCatalyst.map(_(m))
+ n.get
+ } else {
+ // Fall back to atomic type
+ f.dt match {
+ case BooleanType => toBoolean(src, offset)
+ case ByteType => src(offset)
+ case DoubleType => Bytes.toDouble(src, offset)
+ case FloatType => Bytes.toFloat(src, offset)
+ case IntegerType => Bytes.toInt(src, offset)
+ case LongType|TimestampType => Bytes.toLong(src, offset)
+ case ShortType => Bytes.toShort(src, offset)
+ case StringType => toUTF8String(src, offset, length)
+ case BinaryType =>
+ val newArray = new Array[Byte](length)
+ System.arraycopy(src, offset, newArray, 0, length)
+ newArray
+ case _ => SparkSqlSerializer.deserialize[Any](src) //TODO
+ }
+ }
+ }
+
+ // Convert the input value to its HBase byte representation (Avro-encoded when the field defines an Avro schema)
+ def toBytes(input: Any, field: Field): Array[Byte] = {
+ if (field.schema.isDefined) {
+ // Here we assume the top level type is structType
+ val record = field.catalystToAvro(input)
+ AvroSedes.serialize(record, field.schema.get)
+ } else {
+ input match {
+ case data: Boolean => Bytes.toBytes(data)
+ case data: Byte => Array(data)
+ case data: Array[Byte] => data
+ case data: Double => Bytes.toBytes(data)
+ case data: Float => Bytes.toBytes(data)
+ case data: Int => Bytes.toBytes(data)
+ case data: Long => Bytes.toBytes(data)
+ case data: Short => Bytes.toBytes(data)
+ case data: UTF8String => data.getBytes
+ case data: String => Bytes.toBytes(data)
+ // TODO: add more data type support
+ case _ => throw new Exception(s"unsupported data type ${field.dt}")
+ }
+ }
+ }
+
+ def toBoolean(input: Array[Byte], offset: Int): Boolean = {
+ input(offset) != 0
+ }
+
+ def toUTF8String(input: Array[Byte], offset: Int, length: Int): UTF8String = {
+ UTF8String.fromBytes(input.slice(offset, offset + length))
+ }
+}
diff --git hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala
index a2aa3c6..2ef35ff 100644
--- hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala
+++ hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/DefaultSourceSuite.scala
@@ -17,6 +17,8 @@
package org.apache.hadoop.hbase.spark
+import org.apache.avro.Schema
+import org.apache.avro.generic.GenericData
import org.apache.hadoop.hbase.client.{Put, ConnectionFactory}
import org.apache.hadoop.hbase.spark.datasources.HBaseSparkConf
import org.apache.hadoop.hbase.util.Bytes
@@ -46,6 +48,33 @@ object HBaseRecord {
}
}
+
+case class AvroHBaseKeyRecord(col0: Array[Byte],
+ col1: Array[Byte])
+
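+// Builds test records whose row key and cf1:col1 value are the same Avro-serialized "User" record.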
+object AvroHBaseKeyRecord {
+ val schemaString =
+ s"""{"namespace": "example.avro",
+ | "type": "record", "name": "User",
+ | "fields": [ {"name": "name", "type": "string"},
+ | {"name": "favorite_number", "type": ["int", "null"]},
+ | {"name": "favorite_color", "type": ["string", "null"]} ] }""".stripMargin
+
+ val avroSchema: Schema = {
+ val p = new Schema.Parser
+ p.parse(schemaString)
+ }
+
+ def apply(i: Int): AvroHBaseKeyRecord = {
+ val user = new GenericData.Record(avroSchema)
+ user.put("name", s"name${"%03d".format(i)}")
+ user.put("favorite_number", i)
+ user.put("favorite_color", s"color${"%03d".format(i)}")
+ val avroByte = AvroSedes.serialize(user, avroSchema)
+ AvroHBaseKeyRecord(avroByte, avroByte)
+ }
+}
+
class DefaultSourceSuite extends FunSuite with
BeforeAndAfterEach with BeforeAndAfterAll with Logging {
@transient var sc: SparkContext = null
@@ -836,4 +865,107 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
s.show
assert(s.count() == 6)
}
+
+ // catalog for insertion
+ def avroWriteCatalog = s"""{
+ |"table":{"namespace":"default", "name":"avrotable"},
+ |"rowkey":"key",
+ |"columns":{
+ |"col0":{"cf":"rowkey", "col":"key", "type":"binary"},
+ |"col1":{"cf":"cf1", "col":"col1", "type":"binary"}
+ |}
+ |}""".stripMargin
+
+ // catalog for read
+ def avroCatalog = s"""{
+ |"table":{"namespace":"default", "name":"avrotable"},
+ |"rowkey":"key",
+ |"columns":{
+ |"col0":{"cf":"rowkey", "col":"key", "avro":"avroSchema"},
+ |"col1":{"cf":"cf1", "col":"col1", "avro":"avroSchema"}
+ |}
+ |}""".stripMargin
+
+ // for insert to another table
+ def avroCatalogInsert = s"""{
+ |"table":{"namespace":"default", "name":"avrotableInsert"},
+ |"rowkey":"key",
+ |"columns":{
+ |"col0":{"cf":"rowkey", "col":"key", "avro":"avroSchema"},
+ |"col1":{"cf":"cf1", "col":"col1", "avro":"avroSchema"}
+ |}
+ |}""".stripMargin
+
+ def withAvroCatalog(cat: String): DataFrame = {
+ sqlContext
+ .read
+ .options(Map("avroSchema"->AvroHBaseKeyRecord.schemaString,
+ HBaseTableCatalog.tableCatalog->cat))
+ .format("org.apache.hadoop.hbase.spark")
+ .load()
+ }
+
+
+ test("populate avro table") {
+ val sql = sqlContext
+ import sql.implicits._
+
+ val data = (0 to 255).map { i =>
+ AvroHBaseKeyRecord(i)
+ }
+ sc.parallelize(data).toDF.write.options(
+ Map(HBaseTableCatalog.tableCatalog -> avroWriteCatalog,
+ HBaseTableCatalog.newTable -> "5"))
+ .format("org.apache.hadoop.hbase.spark")
+ .save()
+ }
+
+ test("avro empty column") {
+ val df = withAvroCatalog(avroCatalog)
+ df.registerTempTable("avrotable")
+ val c = sqlContext.sql("select count(1) from avrotable")
+ .rdd.collect()(0)(0).asInstanceOf[Long]
+ assert(c == 256)
+ }
+
+ test("avro full query") {
+ val df = withAvroCatalog(avroCatalog)
+ df.show
+ df.printSchema()
+ assert(df.count() == 256)
+ }
+
+ test("avro serialization and deserialization query") {
+ val df = withAvroCatalog(avroCatalog)
+ df.write.options(
+ Map("avroSchema"->AvroHBaseKeyRecord.schemaString,
+ HBaseTableCatalog.tableCatalog->avroCatalogInsert,
+ HBaseTableCatalog.newTable -> "5"))
+ .format("org.apache.hadoop.hbase.spark")
+ .save()
+ val newDF = withAvroCatalog(avroCatalogInsert)
+ newDF.show
+ newDF.printSchema()
+ assert(newDF.count() == 256)
+ }
+
+ test("avro filtered query") {
+ val sql = sqlContext
+ import sql.implicits._
+ val df = withAvroCatalog(avroCatalog)
+ val r = df.filter($"col1.name" === "name005" || $"col1.name" <= "name005")
+ .select("col0", "col1.favorite_color", "col1.favorite_number")
+ r.show
+ assert(r.count() == 6)
+ }
+
+ test("avro Or filter") {
+ val sql = sqlContext
+ import sql.implicits._
+ val df = withAvroCatalog(avroCatalog)
+ val s = df.filter($"col1.name" <= "name005" || $"col1.name".contains("name007"))
+ .select("col0", "col1.favorite_color", "col1.favorite_number")
+ s.show
+ assert(s.count() == 7)
+ }
}