diff --git hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
index 1a3c370..8d94a79 100644
--- hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
+++ hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/DefaultSource.scala
@@ -155,9 +155,10 @@ case class HBaseRelation (
     if (numReg > 3) {
       val tName = TableName.valueOf(catalog.name)
       val cfs = catalog.getColumnFamilies
-      val connection = ConnectionFactory.createConnection(hbaseConf)
+
+      val connection = HBaseConnectionCache.getConnection(hbaseConf)
       // Initialize hBase table if necessary
-      val admin = connection.getAdmin()
+      val admin = connection.getAdmin
       try {
         if (!admin.isTableAvailable(tName)) {
           val tableDesc = new HTableDescriptor(tName)
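
The one-line change above is the pattern this patch applies throughout: each
ConnectionFactory.createConnection(conf) call becomes a lookup in the shared,
reference-counted cache introduced in the new file below. A minimal sketch of
the intended caller pattern (the table name and the try/finally wrapper are
illustrative, not part of the patch):

  val smartConn = HBaseConnectionCache.getConnection(hbaseConf)
  try {
    // SmartConnection mirrors the parts of the Connection API callers need
    val table = smartConn.getTable(TableName.valueOf("my_table"))
    // ... use the table ...
  } finally {
    // Decrements the refCount; the underlying Connection is closed later by
    // the cache's housekeeping thread, not here.
    smartConn.close()
  }
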
diff --git hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseConnectionCache.scala hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseConnectionCache.scala
new file mode 100644
index 0000000..678e769
--- /dev/null
+++ hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseConnectionCache.scala
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hbase.spark
+
+import java.io.IOException
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hbase.client.{Admin, Connection, ConnectionFactory, RegionLocator, Table}
+import org.apache.hadoop.hbase.ipc.RpcControllerFactory
+import org.apache.hadoop.hbase.security.{User, UserProvider}
+import org.apache.hadoop.hbase.spark.datasources.HBaseSparkConf
+import org.apache.hadoop.hbase.{HConstants, TableName}
+import org.apache.spark.Logging
+
+import scala.collection.mutable
+
+private[spark] object HBaseConnectionCache extends Logging {
+
+  // A hashmap of Spark-HBase connections, keyed by HBaseConnectionKey
+  val connectionMap = new mutable.HashMap[HBaseConnectionKey, SmartConnection]()
+
+  // in milliseconds
+  private final val DEFAULT_TIME_OUT: Long = HBaseSparkConf.connectionCloseDelay
+  private var timeout = DEFAULT_TIME_OUT
+  private var closed: Boolean = false
+
+  var housekeepingThread = new Thread(new Runnable {
+    override def run() {
+      while (true) {
+        try {
+          Thread.sleep(timeout)
+        } catch {
+          case e: InterruptedException =>
+            // setTimeout() and close() may interrupt the sleep and it's safe
+            // to ignore the exception
+        }
+        if (closed)
+          return
+        performHousekeeping(false)
+      }
+    }
+  })
+  housekeepingThread.setDaemon(true)
+  housekeepingThread.start()
+
+  def close(): Unit = {
+    try {
+      connectionMap.synchronized {
+        if (closed)
+          return
+        closed = true
+        housekeepingThread.interrupt()
+        housekeepingThread = null
+        HBaseConnectionCache.performHousekeeping(true)
+      }
+    } catch {
+      case e: Exception => logWarning("Error in finalHouseKeeping", e)
+    }
+  }
+
+  def performHousekeeping(forceClean: Boolean) = {
+    val tsNow: Long = System.currentTimeMillis()
+    connectionMap.synchronized {
+      // Collect the entries to evict before mutating the map; removing
+      // entries from a mutable.HashMap while iterating over it is not safe.
+      val expired = connectionMap.filter { case (_, conn) =>
+        if (conn.refCount < 0) {
+          logError(s"Bug to be fixed: negative refCount of connection ${conn}")
+        }
+        forceClean || ((conn.refCount <= 0) && (tsNow - conn.timestamp > timeout))
+      }
+      expired.foreach { case (key, conn) =>
+        try {
+          conn.connection.close()
+        } catch {
+          case e: IOException => logWarning(s"Failed to close connection ${conn}", e)
+        }
+        connectionMap.remove(key)
+      }
+    }
+  }
+
+  // For testing purposes only
+  def getConnection(key: HBaseConnectionKey, conn: => Connection): SmartConnection = {
+    connectionMap.synchronized {
+      if (closed)
+        return null
+      val sc = connectionMap.getOrElseUpdate(key, new SmartConnection(conn))
+      sc.refCount += 1
+      sc
+    }
+  }
+
+  def getConnection(conf: Configuration): SmartConnection =
+    getConnection(new HBaseConnectionKey(conf), ConnectionFactory.createConnection(conf))
+
+  // For testing purposes only
+  def setTimeout(to: Long): Unit = {
+    connectionMap.synchronized {
+      if (closed)
+        return
+      timeout = to
+      housekeepingThread.interrupt()
+    }
+  }
+}
+
+private[hbase] case class SmartConnection (
+    connection: Connection, var refCount: Int = 0, var timestamp: Long = 0) {
+  def getTable(tableName: TableName): Table = connection.getTable(tableName)
+  def getRegionLocator(tableName: TableName): RegionLocator = connection.getRegionLocator(tableName)
+  def isClosed: Boolean = connection.isClosed
+  def getAdmin: Admin = connection.getAdmin
+  def close() = {
+    HBaseConnectionCache.connectionMap.synchronized {
+      refCount -= 1
+      if (refCount <= 0)
+        timestamp = System.currentTimeMillis()
+    }
+  }
+}
+
+/**
+ * Denotes a unique key to an HBase Connection instance.
+ * Please refer to 'org.apache.hadoop.hbase.client.HConnectionKey'.
+ *
+ * In essence, this class captures the properties in Configuration
+ * that may be used in the process of establishing a connection.
+ */
+class HBaseConnectionKey(c: Configuration) extends Logging {
+  val conf: Configuration = c
+  val CONNECTION_PROPERTIES: Array[String] = Array[String](
+    HConstants.ZOOKEEPER_QUORUM,
+    HConstants.ZOOKEEPER_ZNODE_PARENT,
+    HConstants.ZOOKEEPER_CLIENT_PORT,
+    HConstants.ZOOKEEPER_RECOVERABLE_WAITTIME,
+    HConstants.HBASE_CLIENT_PAUSE,
+    HConstants.HBASE_CLIENT_RETRIES_NUMBER,
+    HConstants.HBASE_RPC_TIMEOUT_KEY,
+    HConstants.HBASE_META_SCANNER_CACHING,
+    HConstants.HBASE_CLIENT_INSTANCE_ID,
+    HConstants.RPC_CODEC_CONF_KEY,
+    HConstants.USE_META_REPLICAS,
+    RpcControllerFactory.CUSTOM_CONTROLLER_CONF_KEY)
+
+  var username: String = _
+  var m_properties = mutable.HashMap.empty[String, String]
+  if (conf != null) {
+    for (property <- CONNECTION_PROPERTIES) {
+      val value: String = conf.get(property)
+      if (value != null) {
+        m_properties.+=((property, value))
+      }
+    }
+    try {
+      val provider: UserProvider = UserProvider.instantiate(conf)
+      val currentUser: User = provider.getCurrent
+      if (currentUser != null) {
+        username = currentUser.getName
+      }
+    } catch {
+      case e: IOException =>
+        logWarning("Error obtaining current user, skipping username in HBaseConnectionKey", e)
+    }
+  }
+
+  // make 'properties' immutable
+  val properties = m_properties.toMap
+
+  override def hashCode: Int = {
+    val prime: Int = 31
+    var result: Int = 1
+    if (username != null) {
+      result = username.hashCode
+    }
+    for (property <- CONNECTION_PROPERTIES) {
+      val value: Option[String] = properties.get(property)
+      if (value.isDefined) {
+        result = prime * result + value.hashCode
+      }
+    }
+    result
+  }
+
+  override def equals(obj: Any): Boolean = {
+    if (obj == null || (getClass ne obj.getClass)) {
+      return false
+    }
+    val that = obj.asInstanceOf[HBaseConnectionKey]
+    // Scala's == on references is null-safe, so this also covers the case
+    // where exactly one of the two usernames is null
+    if (this.username != that.username) {
+      return false
+    }
+    // 'properties' is always built in the constructor and never null, so the
+    // keys are equal iff every tracked property resolves to the same value
+    CONNECTION_PROPERTIES.forall { property =>
+      this.properties.get(property) == that.properties.get(property)
+    }
+  }
+
+  override def toString: String = {
+    "HBaseConnectionKey{" + "properties=" + properties + ", username='" + username + '\'' + '}'
+  }
+}
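
Note the division of labor above: SmartConnection.close() only decrements the
refCount and records the release time; the real Connection.close() happens
later in performHousekeeping(), once a connection has sat unreferenced for
longer than the timeout. Because the Connection argument to getConnection is
by-name, it is only evaluated on a cache miss, so two lookups with equivalent
configurations share one underlying Connection. A sketch of that behavior
(assuming conf is a Configuration whose tracked properties match across both
calls):

  val c1 = HBaseConnectionCache.getConnection(conf)
  val c2 = HBaseConnectionCache.getConnection(conf) // same HBaseConnectionKey
  assert(c1.connection eq c2.connection)            // one shared Connection, refCount == 2
  c1.close()                                        // refCount drops to 1; stays open
  c2.close()                                        // refCount 0; evicted after the close delay
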
diff --git hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
index a9b38ba..aeffecb 100644
--- hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
+++ hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
@@ -482,9 +482,9 @@ class HBaseContext(@transient sc: SparkContext,
     applyCreds
     // specify that this is a proxy user
-    val connection = ConnectionFactory.createConnection(config)
-    f(it, connection)
-    connection.close()
+    val smartConn = HBaseConnectionCache.getConnection(config)
+    f(it, smartConn.connection)
+    smartConn.close()
   }
 
   private def getConf(configBroadcast: Broadcast[SerializableWritable[Configuration]]):
@@ -522,11 +522,10 @@ class HBaseContext(@transient sc: SparkContext,
     val config = getConf(configBroadcast)
     applyCreds
 
-    val connection = ConnectionFactory.createConnection(config)
-    val res = mp(it, connection)
-    connection.close()
+    val smartConn = HBaseConnectionCache.getConnection(config)
+    val res = mp(it, smartConn.connection)
+    smartConn.close()
     res
-
   }
 
   /**
@@ -619,7 +618,7 @@ class HBaseContext(@transient sc: SparkContext,
       compactionExclude: Boolean = false,
       maxSize:Long = HConstants.DEFAULT_MAX_FILE_SIZE): Unit = {
-    val conn = ConnectionFactory.createConnection(config)
+    val conn = HBaseConnectionCache.getConnection(config)
     val regionLocator = conn.getRegionLocator(tableName)
     val startKeys = regionLocator.getStartKeys
     val defaultCompressionStr = config.get("hfile.compression",
@@ -742,7 +741,7 @@ class HBaseContext(@transient sc: SparkContext,
       compactionExclude: Boolean = false,
       maxSize:Long = HConstants.DEFAULT_MAX_FILE_SIZE): Unit = {
-    val conn = ConnectionFactory.createConnection(config)
+    val conn = HBaseConnectionCache.getConnection(config)
     val regionLocator = conn.getRegionLocator(tableName)
     val startKeys = regionLocator.getStartKeys
     val defaultCompressionStr = config.get("hfile.compression",
diff --git hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseResources.scala hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseResources.scala
index 14c5fd0..7e0a862 100644
--- hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseResources.scala
+++ hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseResources.scala
@@ -19,7 +19,8 @@ package org.apache.hadoop.hbase.spark.datasources
 
 import org.apache.hadoop.hbase.TableName
 import org.apache.hadoop.hbase.client._
-import org.apache.hadoop.hbase.spark.HBaseRelation
+import org.apache.hadoop.hbase.spark.{HBaseConnectionKey, SmartConnection,
+  HBaseConnectionCache, HBaseRelation}
 import scala.language.implicitConversions
 
 // Resource and ReferencedResources are defined for extensibility,
@@ -84,11 +85,11 @@ trait ReferencedResource {
 }
 
 case class TableResource(relation: HBaseRelation) extends ReferencedResource {
-  var connection: Connection = _
+  var connection: SmartConnection = _
   var table: Table = _
 
   override def init(): Unit = {
-    connection = ConnectionFactory.createConnection(relation.hbaseConf)
+    connection = HBaseConnectionCache.getConnection(relation.hbaseConf)
     table = connection.getTable(TableName.valueOf(relation.tableName))
   }
 
@@ -113,7 +114,7 @@ case class TableResource(relation: HBaseRelation) extends ReferencedResource {
 }
 
 case class RegionResource(relation: HBaseRelation) extends ReferencedResource {
-  var connection: Connection = _
+  var connection: SmartConnection = _
   var rl: RegionLocator = _
   val regions = releaseOnException {
     val keys = rl.getStartEndKeys
@@ -127,7 +128,7 @@ case class RegionResource(relation: HBaseRelation) extends ReferencedResource {
   }
 
   override def init(): Unit = {
-    connection = ConnectionFactory.createConnection(relation.hbaseConf)
+    connection = HBaseConnectionCache.getConnection(relation.hbaseConf)
     rl = connection.getRegionLocator(TableName.valueOf(relation.tableName))
   }
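
The next hunk adds the constant the cache reads as its default eviction
delay. The value below works out to ten minutes (10 * 60 * 1000 = 600000 ms);
the test suite at the end of this patch shortens it through the test-only
setter, e.g.:

  HBaseConnectionCache.setTimeout(1000) // evict unreferenced connections after ~1s
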
"hbase.spark.query.maxVersions" val ENCODER = "hbase.spark.query.encoder" val defaultEncoder = classOf[NaiveEncoder].getCanonicalName + + // in milliseconds + val connectionCloseDelay = 10 * 60 * 1000 } diff --git hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableScanRDD.scala hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableScanRDD.scala index 5b45ef9..7761acb 100644 --- hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableScanRDD.scala +++ hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/datasources/HBaseTableScanRDD.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.hbase.client._ import org.apache.hadoop.hbase.spark._ import org.apache.hadoop.hbase.spark.hbase._ import org.apache.hadoop.hbase.spark.datasources.HBaseResources._ +import org.apache.hadoop.hbase.util.ShutdownHookManager import org.apache.spark.sql.datasources.hbase.Field import org.apache.spark.{SparkEnv, TaskContext, Logging, Partition} import org.apache.spark.rdd.RDD @@ -85,10 +86,14 @@ class HBaseTableScanRDD(relation: HBaseRelation, } }.toArray regions.release() + ShutdownHookManager.affixShutdownHook( new Thread() { + override def run() { + HBaseConnectionCache.close() + } + }, 0) ps.asInstanceOf[Array[Partition]] } - override def getPreferredLocations(split: Partition): Seq[String] = { split.asInstanceOf[HBaseScanPartition].regions.server.map { identity @@ -148,7 +153,6 @@ class HBaseTableScanRDD(relation: HBaseRelation, iterator } - private def buildScan(range: Range, filter: Option[SparkSQLPushDownFilter], columns: Seq[Field]): Scan = { @@ -226,6 +230,11 @@ class HBaseTableScanRDD(relation: HBaseRelation, .fold(Iterator.empty: Iterator[Result]){ case (x, y) => x ++ y } ++ gIt + ShutdownHookManager.affixShutdownHook( new Thread() { + override def run() { + HBaseConnectionCache.close() + } + }, 0) rIts } diff --git hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/HBaseConnectionCacheSuite.scala hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/HBaseConnectionCacheSuite.scala new file mode 100644 index 0000000..066d534 --- /dev/null +++ hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/HBaseConnectionCacheSuite.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.hbase.spark + +import java.util.concurrent.ExecutorService +import scala.util.Random + +import org.apache.hadoop.hbase.client.{BufferedMutator, Table, RegionLocator, + Connection, BufferedMutatorParams, Admin} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hbase.TableName +import org.apache.spark.Logging +import org.scalatest.FunSuite + +case class HBaseConnectionKeyMocker (confId: Int) extends HBaseConnectionKey (null) { + override def hashCode: Int = { + confId + } + + override def equals(obj: Any): Boolean = { + if(!obj.isInstanceOf[HBaseConnectionKeyMocker]) + false + else + confId == obj.asInstanceOf[HBaseConnectionKeyMocker].confId + } +} + +class ConnectionMocker extends Connection { + var isClosed: Boolean = false + + def getRegionLocator (tableName: TableName): RegionLocator = null + def getConfiguration: Configuration = null + def getTable (tableName: TableName): Table = null + def getTable(tableName: TableName, pool: ExecutorService): Table = null + def getBufferedMutator (params: BufferedMutatorParams): BufferedMutator = null + def getBufferedMutator (tableName: TableName): BufferedMutator = null + def getAdmin: Admin = null + + def close(): Unit = { + if (isClosed) + throw new IllegalStateException() + isClosed = true + } + + def isAborted: Boolean = true + def abort(why: String, e: Throwable) = {} +} + +class HBaseConnectionCacheSuite extends FunSuite with Logging { + /* + * These tests must be performed sequentially as they operate with an + * unique running thread and resource. + * + * It looks there's no way to tell FunSuite to do so, so making those + * test cases normal functions which are called sequentially in a single + * test case. + */ + test("all test cases") { + testBasic() + testWithPressureWithoutClose() + testWithPressureWithClose() + } + + def testBasic() { + HBaseConnectionCache.setTimeout(1 * 1000) + + val connKeyMocker1 = new HBaseConnectionKeyMocker(1) + val connKeyMocker1a = new HBaseConnectionKeyMocker(1) + val connKeyMocker2 = new HBaseConnectionKeyMocker(2) + + val c1 = HBaseConnectionCache + .getConnection(connKeyMocker1, new ConnectionMocker) + val c1a = HBaseConnectionCache + .getConnection(connKeyMocker1a, new ConnectionMocker) + + HBaseConnectionCache.connectionMap.synchronized { + assert(HBaseConnectionCache.connectionMap.size === 1) + } + + val c2 = HBaseConnectionCache + .getConnection(connKeyMocker2, new ConnectionMocker) + + HBaseConnectionCache.connectionMap.synchronized { + assert(HBaseConnectionCache.connectionMap.size === 2) + } + + c1.close() + HBaseConnectionCache.connectionMap.synchronized { + assert(HBaseConnectionCache.connectionMap.size === 2) + } + + c1a.close() + HBaseConnectionCache.connectionMap.synchronized { + assert(HBaseConnectionCache.connectionMap.size === 2) + } + + Thread.sleep(3 * 1000) // Leave housekeeping thread enough time + HBaseConnectionCache.connectionMap.synchronized { + assert(HBaseConnectionCache.connectionMap.size === 1) + assert(HBaseConnectionCache.connectionMap.iterator.next()._1 + .asInstanceOf[HBaseConnectionKeyMocker].confId === 2) + } + + c2.close() + } + + def testWithPressureWithoutClose() { + class TestThread extends Runnable { + override def run() { + for (i <- 0 to 999) { + val c = HBaseConnectionCache.getConnection( + new HBaseConnectionKeyMocker(Random.nextInt(10)), new ConnectionMocker) + } + } + } + + HBaseConnectionCache.setTimeout(500) + val threads: Array[Thread] = new Array[Thread](100) + for (i <- 0 to 99) { + threads.update(i, 
new Thread(new TestThread())) + threads(i).run() + } + try { + threads.foreach { x => x.join() } + } catch { + case e: InterruptedException => println(e.getMessage) + } + + Thread.sleep(1000) + HBaseConnectionCache.connectionMap.synchronized { + assert(HBaseConnectionCache.connectionMap.size === 10) + var totalRc : Int = 0 + HBaseConnectionCache.connectionMap.foreach { + x => totalRc += x._2.refCount + } + assert(totalRc === 100 * 1000) + HBaseConnectionCache.connectionMap.foreach { + x => { + x._2.refCount = 0 + x._2.timestamp = System.currentTimeMillis() - 1000 + } + } + } + Thread.sleep(1000) + assert(HBaseConnectionCache.connectionMap.size === 0) + } + + def testWithPressureWithClose() { + class TestThread extends Runnable { + override def run() { + for (i <- 0 to 999) { + val c = HBaseConnectionCache.getConnection( + new HBaseConnectionKeyMocker(Random.nextInt(10)), new ConnectionMocker) + Thread.`yield`() + c.close() + } + } + } + + HBaseConnectionCache.setTimeout(3 * 1000) + val threads: Array[Thread] = new Array[Thread](100) + for (i <- threads.indices) { + threads.update(i, new Thread(new TestThread())) + threads(i).run() + } + try { + threads.foreach { x => x.join() } + } catch { + case e: InterruptedException => println(e.getMessage) + } + + HBaseConnectionCache.connectionMap.synchronized { + assert(HBaseConnectionCache.connectionMap.size === 10) + } + + Thread.sleep(6 * 1000) + HBaseConnectionCache.connectionMap.synchronized { + assert(HBaseConnectionCache.connectionMap.size === 0) + } + } +}
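
A quick sanity check on the bookkeeping in testWithPressureWithoutClose: 100
threads each call getConnection 1000 times against 10 possible keys, so the
cache ends up holding 10 SmartConnections whose refCounts sum to
100 * 1000 = 100000. The same invariant in a condensed, single-threaded form
(a hypothetical snippet reusing the mockers above, not part of the suite):

  val keys = 10
  val calls = 100 * 1000
  for (i <- 0 until calls)
    HBaseConnectionCache.getConnection(
      new HBaseConnectionKeyMocker(i % keys), new ConnectionMocker)
  HBaseConnectionCache.connectionMap.synchronized {
    assert(HBaseConnectionCache.connectionMap.size === keys)
    assert(HBaseConnectionCache.connectionMap.values.map(_.refCount).sum === calls)
  }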