diff --git a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
index b6d7982..844b5b5 100644
--- a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
+++ b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
@@ -164,7 +164,7 @@ case class HBaseRelation (val tableName:String,
     HBaseSparkConf.BULKGET_SIZE, HBaseSparkConf.defaultBulkGetSize))
 
   //create or get latest HBaseContext
-  @transient val hbaseContext:HBaseContext = if (useHBaseContext) {
+  val hbaseContext:HBaseContext = if (useHBaseContext) {
     LatestHBaseContextCache.latest
   } else {
     val config = HBaseConfiguration.create()
@@ -270,7 +270,7 @@ case class HBaseRelation (val tableName:String,
     } else {
       None
     }
-    val hRdd = new HBaseTableScanRDD(this, pushDownFilterJava, requiredQualifierDefinitionList.seq)
+    val hRdd = new HBaseTableScanRDD(this, hbaseContext, pushDownFilterJava, requiredQualifierDefinitionList.seq)
     pushDownRowKeyFilter.points.foreach(hRdd.addPoint(_))
     pushDownRowKeyFilter.ranges.foreach(hRdd.addRange(_))
     var resultRDD: RDD[Row] = {
diff --git a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
index 2d21e69..61ed3cf 100644
--- a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
+++ b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
@@ -29,6 +29,7 @@ import org.apache.hadoop.hbase.io.encoding.DataBlockEncoding
 import org.apache.hadoop.hbase.io.hfile.{CacheConfig, HFileContextBuilder, HFileWriterImpl}
 import org.apache.hadoop.hbase.regionserver.{HStore, StoreFile, BloomType}
 import org.apache.hadoop.hbase.util.Bytes
+import org.apache.hadoop.mapred.JobConf
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.rdd.RDD
@@ -228,7 +229,7 @@ class HBaseContext(@transient sc: SparkContext,
     }))
   }
 
-  def applyCreds[T] (configBroadcast: Broadcast[SerializableWritable[Configuration]]){
+  def applyCreds[T] (){
     credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
 
     logDebug("appliedCredentials:" + appliedCredentials + ",credentials:" + credentials)
@@ -440,10 +441,14 @@ class HBaseContext(@transient sc: SparkContext,
     TableMapReduceUtil.initTableMapperJob(tableName, scan,
       classOf[IdentityTableMapper], null, null, job)
 
-    sc.newAPIHadoopRDD(job.getConfiguration,
+    val jconf = new JobConf(job.getConfiguration)
+    SparkHadoopUtil.get.addCredentials(jconf)
+    new NewHBaseRDD(sc,
       classOf[TableInputFormat],
       classOf[ImmutableBytesWritable],
-      classOf[Result]).map(f)
+      classOf[Result],
+      job.getConfiguration,
+      this).map(f)
   }
 
   /**
@@ -474,7 +479,7 @@ class HBaseContext(@transient sc: SparkContext,
 
     val config = getConf(configBroadcast)
 
-    applyCreds(configBroadcast)
+    applyCreds
     // specify that this is a proxy user
     val connection = ConnectionFactory.createConnection(config)
     f(it, connection)
@@ -514,7 +519,7 @@ class HBaseContext(@transient sc: SparkContext,
                                        Iterator[U]): Iterator[U] = {
 
     val config = getConf(configBroadcast)
-    applyCreds(configBroadcast)
+    applyCreds
 
     val connection = ConnectionFactory.createConnection(config)
     val res = mp(it, connection)
diff --git a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/NewHBaseRDD.scala b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/NewHBaseRDD.scala
new file mode 100644
index 0000000..8e5e8f9
--- /dev/null
+++ b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/NewHBaseRDD.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hbase.spark
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.mapreduce.InputFormat
+import org.apache.spark.rdd.NewHadoopRDD
+import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
+
+class NewHBaseRDD[K,V](@transient sc : SparkContext,
+                       @transient inputFormatClass: Class[_ <: InputFormat[K, V]],
+                       @transient keyClass: Class[K],
+                       @transient valueClass: Class[V],
+                       @transient conf: Configuration,
+                       val hBaseContext: HBaseContext) extends NewHadoopRDD(sc,inputFormatClass, keyClass, valueClass, conf) {
+
+  override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
+    hBaseContext.applyCreds()
+    super.compute(theSplit, context)
+  }
+}
diff --git a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableScanRDD.scala b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableScanRDD.scala
index f288c34..d859957 100644
--- a/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableScanRDD.scala
+++ b/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableScanRDD.scala
@@ -20,7 +20,7 @@ package org.apache.hadoop.hbase.spark.datasources
 import java.util.ArrayList
 
 import org.apache.hadoop.hbase.client._
-import org.apache.hadoop.hbase.spark.{ScanRange, SchemaQualifierDefinition, HBaseRelation, SparkSQLPushDownFilter}
+import org.apache.hadoop.hbase.spark._
 import org.apache.hadoop.hbase.spark.hbase._
 import org.apache.hadoop.hbase.spark.datasources.HBaseResources._
 import org.apache.spark.{SparkEnv, TaskContext, Logging, Partition}
@@ -28,10 +28,10 @@ import org.apache.spark.rdd.RDD
 
 import scala.collection.mutable
 
-
 class HBaseTableScanRDD(relation: HBaseRelation,
-     @transient val filter: Option[SparkSQLPushDownFilter] = None,
-     val columns: Seq[SchemaQualifierDefinition] = Seq.empty
+     val hbaseContext: HBaseContext,
+     @transient val filter: Option[SparkSQLPushDownFilter] = None,
+     val columns: Seq[SchemaQualifierDefinition] = Seq.empty
      )extends RDD[Result](relation.sqlContext.sparkContext, Nil) with Logging {
   private def sparkConf = SparkEnv.get.conf
   @transient var ranges = Seq.empty[Range]
@@ -98,7 +98,8 @@ class HBaseTableScanRDD(relation: HBaseRelation,
       tbr: TableResource,
       g: Seq[Array[Byte]],
       filter: Option[SparkSQLPushDownFilter],
-      columns: Seq[SchemaQualifierDefinition]): Iterator[Result] = {
+      columns: Seq[SchemaQualifierDefinition],
+      hbaseContext: HBaseContext): Iterator[Result] = {
     g.grouped(relation.bulkGetSize).flatMap{ x =>
       val gets = new ArrayList[Get]()
       x.foreach{ y =>
@@ -111,6 +112,7 @@ class HBaseTableScanRDD(relation: HBaseRelation,
         filter.foreach(g.setFilter(_))
         gets.add(g)
       }
+      hbaseContext.applyCreds()
       val tmp = tbr.get(gets)
       rddResources.addResource(tmp)
       toResultIterator(tmp)
@@ -208,11 +210,12 @@ class HBaseTableScanRDD(relation: HBaseRelation,
       if (points.isEmpty) {
         Iterator.empty: Iterator[Result]
       } else {
-        buildGets(tableResource, points, filter, columns)
+        buildGets(tableResource, points, filter, columns, hbaseContext)
       }
     }
     val rIts = scans.par
      .map { scan =>
+        hbaseContext.applyCreds()
        val scanner = tableResource.getScanner(scan)
        rddResources.addResource(scanner)
        scanner
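
For context on how a client would exercise this patch: the change matters on Kerberos-secured clusters, where executor tasks need the driver's HBase credentials before they can open scanners or issue gets; NewHBaseRDD.compute() and the applyCreds() calls above re-apply those credentials on the executor side. The driver program below is a minimal sketch, not part of this diff, and makes a few assumptions: the table name "t1" is hypothetical, the hbaseRDD(TableName, Scan) overload of HBaseContext shown being rewired in the patch is available, and the driver has already logged in (kinit or spark-submit --keytab).

    import org.apache.hadoop.hbase.{HBaseConfiguration, TableName}
    import org.apache.hadoop.hbase.client.Scan
    import org.apache.hadoop.hbase.spark.HBaseContext
    import org.apache.hadoop.hbase.util.Bytes
    import org.apache.spark.{SparkConf, SparkContext}

    object HBaseScanExample {
      def main(args: Array[String]): Unit = {
        // Master and queue settings are expected to come from spark-submit.
        val sc = new SparkContext(new SparkConf().setAppName("hbase-scan-example"))
        val conf = HBaseConfiguration.create()

        // Constructing the HBaseContext on the driver captures the current
        // user's credentials; per the patch, NewHBaseRDD.compute() re-applies
        // them on each executor before TableInputFormat opens its scanner.
        val hbaseContext = new HBaseContext(sc, conf)

        val scan = new Scan()
        scan.setCaching(100)

        // With this patch, hbaseRDD returns a NewHBaseRDD under the hood,
        // so the scan also works with Kerberos enabled.
        val rdd = hbaseContext.hbaseRDD(TableName.valueOf("t1"), scan)
        rdd.map { case (_, result) => Bytes.toString(result.getRow) }
          .take(5)
          .foreach(println)

        sc.stop()
      }
    }

The same mechanism covers the SparkSQL path: DefaultSource now hands its hbaseContext to HBaseTableScanRDD, which calls applyCreds() before bulk gets and before each parallel scan, so DataFrame reads through the datasource pick up credentials the same way.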