diff --git a/java/kudu-spark/src/main/scala/org/kududb/spark/KuduContext.scala b/java/kudu-spark/src/main/scala/org/kududb/spark/KuduContext.scala
index a034099..e2468df 100644
--- a/java/kudu-spark/src/main/scala/org/kududb/spark/KuduContext.scala
+++ b/java/kudu-spark/src/main/scala/org/kududb/spark/KuduContext.scala
@@ -21,10 +21,13 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.io.NullWritable
 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext
+import org.apache.spark.streaming.dstream.DStream
 import org.kududb.annotations.InterfaceStability
 import org.kududb.client.{AsyncKuduClient, KuduClient, RowResult}
 import org.kududb.mapreduce.KuduTableInputFormat
 
+import scala.reflect.ClassTag
+
 /**
  * KuduContext is a façade for Kudu operations.
  *
@@ -33,10 +36,12 @@ import org.kududb.mapreduce.KuduTableInputFormat
  * share connections among the tasks in a JVM.
  */
 @InterfaceStability.Unstable
-class KuduContext(kuduMaster: String) extends Serializable {
+class KuduContext(@transient sc: SparkContext,
+                  @transient kuduMaster: String) extends Serializable {
   @transient lazy val syncClient = new KuduClient.KuduClientBuilder(kuduMaster).build()
   @transient lazy val asyncClient = new AsyncKuduClient.AsyncKuduClientBuilder(kuduMaster).build()
+  val broadcastedKuduMaster = sc.broadcast(kuduMaster)
 
   /**
    * Create an RDD from a Kudu table.
    *
@@ -47,8 +52,7 @@ class KuduContext(kuduMaster: String) extends Serializable {
    *                         '*' means to project all columns.
    * @return a new RDD that maps over the given table for the selected columns
    */
-  def kuduRDD(sc: SparkContext,
-              tableName: String,
+  def kuduRDD(tableName: String,
               columnProjection: Seq[String] = Nil): RDD[RowResult] = {
     val conf = new Configuration
 
@@ -64,4 +68,170 @@ class KuduContext(kuduMaster: String) extends Serializable {
     val columnNames = if (columnProjection.nonEmpty) columnProjection.mkString(", ") else "(*)"
     rdd.values.setName(s"KuduRDD { table=$tableName, columnProjection=$columnNames }")
   }
+
+  /**
+   * A simple enrichment of the traditional Spark RDD foreachPartition.
+   * This function differs from the original in that it offers the
+   * developer access to already-connected KuduClient and AsyncKuduClient
+   * objects.
+   *
+   * Note: Do not close the clients. All client management is handled
+   * outside this method.
+   *
+   * @param rdd Original RDD with data to iterate over
+   * @param f   Function given an iterator over the RDD values and the
+   *            connected clients to interact with Kudu
+   */
+  def foreachPartition[T](rdd: RDD[T],
+                          f: (Iterator[T], KuduClient, AsyncKuduClient) => Unit): Unit = {
+    rdd.foreachPartition(
+      it => kuduForeachPartition(it, f))
+  }
+
+  /**
+   * A simple enrichment of the traditional Spark Streaming DStream
+   * foreachRDD: the supplied function is applied to each partition of every
+   * RDD in the stream.
+   * This function differs from the original in that it offers the
+   * developer access to already-connected KuduClient and AsyncKuduClient
+   * objects.
+   *
+   * Note: Do not close the clients. All client management is handled
+   * outside this method.
+   *
+   * @param dstream Original DStream with data to iterate over
+   * @param f       Function given an iterator over the DStream values and
+   *                the connected clients to interact with Kudu
+   */
+  def foreachPartition[T](dstream: DStream[T],
+                          f: (Iterator[T], KuduClient, AsyncKuduClient) => Unit): Unit = {
+    dstream.foreachRDD((rdd, time) => {
+      foreachPartition(rdd, f)
+    })
+  }
+
+  /**
+   * A simple enrichment of the traditional Spark RDD mapPartitions.
+   * This function differs from the original in that it offers the
+   * developer access to already-connected KuduClient and AsyncKuduClient
+   * objects.
+   *
+   * Note: Do not close the clients. All client management is handled
+   * outside this method.
+   *
+   * @param rdd Original RDD with data to iterate over
+   * @param mp  Function given an iterator over the RDD values and the
+   *            connected clients to interact with Kudu
+   * @return a new RDD generated by the user-defined function, just like a
+   *         normal mapPartitions
+   */
+  def mapPartitions[T, R: ClassTag](rdd: RDD[T],
+      mp: (Iterator[T], KuduClient, AsyncKuduClient) => Iterator[R]): RDD[R] = {
+    rdd.mapPartitions[R](it => kuduMapPartition[T, R](it, mp))
+  }
+
+  /**
+   * A simple enrichment of the traditional Spark Streaming DStream
+   * foreachPartition.
+   *
+   * This function differs from the original in that it offers the
+   * developer access to already-connected KuduClient and AsyncKuduClient
+   * objects.
+   *
+   * Note: Do not close the clients. All client management is handled
+   * outside this method.
+   *
+   * Note: Make sure to partition correctly to avoid memory issues when
+   * reading from Kudu.
+   *
+   * @param dstream Original DStream with data to iterate over
+   * @param f       Function given an iterator over the DStream values and
+   *                the connected clients to interact with Kudu
+   */
+  def streamForeachPartition[T](dstream: DStream[T],
+      f: (Iterator[T], KuduClient, AsyncKuduClient) => Unit): Unit = {
+    dstream.foreachRDD(rdd => this.foreachPartition(rdd, f))
+  }
+
+  /**
+   * A simple enrichment of the traditional Spark Streaming DStream
+   * mapPartitions.
+   *
+   * This function differs from the original in that it offers the
+   * developer access to already-connected KuduClient and AsyncKuduClient
+   * objects.
+   *
+   * Note: Do not close the clients. All client management is handled
+   * outside this method.
+   *
+   * Note: Make sure to partition correctly to avoid memory issues when
+   * reading from Kudu.
+   *
+   * @param dstream Original DStream with data to iterate over
+   * @param f       Function given an iterator over the DStream values and
+   *                the connected clients to interact with Kudu
+   * @return a new DStream generated by the user-defined function, just like
+   *         a normal mapPartitions
+   */
+  def streamMapPartitions[T, U: ClassTag](dstream: DStream[T],
+      f: (Iterator[T], KuduClient, AsyncKuduClient) => Iterator[U]): DStream[U] = {
+    dstream.mapPartitions(it => kuduMapPartition[T, U](it, f))
+  }
+
+  /**
+   * Underlying wrapper for all foreach functions in KuduContext.
+   */
+  private def kuduForeachPartition[T](it: Iterator[T],
+      f: (Iterator[T], KuduClient, AsyncKuduClient) => Unit): Unit = {
+    f(it, KuduClientCache.getKuduClient(broadcastedKuduMaster.value),
+      KuduClientCache.getAsyncKuduClient(broadcastedKuduMaster.value))
+  }
+
+  /**
+   * Underlying wrapper for all mapPartition functions in KuduContext.
+   */
+  private def kuduMapPartition[K, U](it: Iterator[K],
+      mp: (Iterator[K], KuduClient, AsyncKuduClient) => Iterator[U]): Iterator[U] = {
+    mp(it,
+      KuduClientCache.getKuduClient(broadcastedKuduMaster.value),
+      KuduClientCache.getAsyncKuduClient(broadcastedKuduMaster.value))
+  }
 }
+
+/**
+ * Caches one KuduClient and one AsyncKuduClient per JVM so tasks can share
+ * connections to the given Kudu master.
+ */
+object KuduClientCache {
+  var kuduClient: KuduClient = null
+  var asyncKuduClient: AsyncKuduClient = null
+
+  def getKuduClient(kuduMaster: String): KuduClient = {
+    this.synchronized {
+      if (kuduClient == null) {
+        kuduClient = new KuduClient.KuduClientBuilder(kuduMaster).build()
+      }
+      kuduClient
+    }
+  }
+
+  def getAsyncKuduClient(kuduMaster: String): AsyncKuduClient = {
+    this.synchronized {
+      if (asyncKuduClient == null) {
+        asyncKuduClient = new AsyncKuduClient.AsyncKuduClientBuilder(kuduMaster).build()
+      }
+      asyncKuduClient
+    }
+  }
+}
diff --git a/java/kudu-spark/src/test/scala/org/kududb/spark/KuduContextTest.scala b/java/kudu-spark/src/test/scala/org/kududb/spark/KuduContextTest.scala
index 67aad7b..a3e3ac1 100644
--- a/java/kudu-spark/src/test/scala/org/kududb/spark/KuduContextTest.scala
+++ b/java/kudu-spark/src/test/scala/org/kududb/spark/KuduContextTest.scala
@@ -27,7 +27,7 @@ class KuduContextTest extends FunSuite with TestContext {
     insertRows(rowCount)
 
-    val scanRdd = kuduContext.kuduRDD(sc, "test")
+    val scanRdd = kuduContext.kuduRDD("test")
     val scanList = scanRdd.map(r => r.getInt(0)).collect()
     assert(scanList.length == rowCount)
 
diff --git a/java/kudu-spark/src/test/scala/org/kududb/spark/TestContext.scala b/java/kudu-spark/src/test/scala/org/kududb/spark/TestContext.scala
index 9876282..a4f6f33 100644
--- a/java/kudu-spark/src/test/scala/org/kududb/spark/TestContext.scala
+++ b/java/kudu-spark/src/test/scala/org/kududb/spark/TestContext.scala
@@ -55,7 +55,7 @@ trait TestContext extends BeforeAndAfterAll { self: Suite =>
     kuduClient = new KuduClientBuilder(miniCluster.getMasterAddresses).build()
     assert(miniCluster.waitForTabletServers(1))
 
-    kuduContext = new KuduContext(miniCluster.getMasterAddresses)
+    kuduContext = new KuduContext(sc, miniCluster.getMasterAddresses)
 
     val tableOptions = new CreateTableOptions().setNumReplicas(1)
     table = kuduClient.createTable(tableName, schema, tableOptions)
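
Usage sketch (not part of the patch): assuming an existing SparkContext `sc`, a master address string `masterAddr`, and a pre-created Kudu table named "test" with an Int32 key column, the new per-partition API could be exercised roughly like this:

    // Sketch only: "masterAddr" and the "test" table are assumptions for illustration.
    val kuduContext = new KuduContext(sc, masterAddr)

    // Scan through the simplified kuduRDD signature (no SparkContext argument).
    val rows = kuduContext.kuduRDD("test", Seq("key"))

    // Iterate each partition with the JVM-cached clients; do not close them here.
    kuduContext.foreachPartition(rows, (it, syncClient, asyncClient) => {
      val table = syncClient.openTable("test")
      it.foreach(row => println(s"key=${row.getInt(0)} from ${table.getName}"))
    })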