diff --git a/spark/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala b/spark/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
index e50a3e8..890e67f 100644
--- a/spark/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
+++ b/spark/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
@@ -65,13 +65,11 @@ class HBaseContext(@transient val sc: SparkContext,
                    val tmpHdfsConfgFile: String = null)
   extends Serializable with Logging {
 
-  @transient var credentials = UserGroupInformation.getCurrentUser().getCredentials()
   @transient var tmpHdfsConfiguration:Configuration = config
   @transient var appliedCredentials = false
   @transient val job = Job.getInstance(config)
   TableMapReduceUtil.initCredentials(job)
   val broadcastedConf = sc.broadcast(new SerializableWritable(config))
-  val credentialsConf = sc.broadcast(new SerializableWritable(job.getCredentials))
 
   LatestHBaseContextCache.latest = this
 
@@ -233,21 +231,12 @@ class HBaseContext(@transient val sc: SparkContext,
   }
 
   def applyCreds[T] (){
-    credentials = UserGroupInformation.getCurrentUser().getCredentials()
-
-    if (log.isDebugEnabled) {
-      logDebug("appliedCredentials:" + appliedCredentials + ",credentials:" + credentials)
-    }
-
-    if (!appliedCredentials && credentials != null) {
+    if (!appliedCredentials) {
       appliedCredentials = true
 
       @transient val ugi = UserGroupInformation.getCurrentUser
-      ugi.addCredentials(credentials)
       // specify that this is a proxy user
       ugi.setAuthenticationMethod(AuthenticationMethod.PROXY)
-
-      ugi.addCredentials(credentialsConf.value.value)
     }
   }
 
diff --git a/spark/hbase-spark/src/test/java/org/apache/hadoop/hbase/spark/TestJavaHBaseContext.java b/spark/hbase-spark/src/test/java/org/apache/hadoop/hbase/spark/TestJavaHBaseContext.java
index 4134ee6..4df64da 100644
--- a/spark/hbase-spark/src/test/java/org/apache/hadoop/hbase/spark/TestJavaHBaseContext.java
+++ b/spark/hbase-spark/src/test/java/org/apache/hadoop/hbase/spark/TestJavaHBaseContext.java
@@ -52,8 +52,10 @@ import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.junit.After;
+import org.junit.AfterClass;
 import org.junit.Assert;
 import org.junit.Before;
+import org.junit.BeforeClass;
 import org.junit.ClassRule;
 import org.junit.Test;
 import org.junit.experimental.categories.Category;
@@ -70,11 +72,10 @@ public class TestJavaHBaseContext implements Serializable {
   public static final HBaseClassTestRule TIMEOUT =
       HBaseClassTestRule.forClass(TestJavaHBaseContext.class);
 
-  private transient JavaSparkContext jsc;
-  HBaseTestingUtility htu;
-  protected static final Logger LOG = LoggerFactory.getLogger(TestJavaHBaseContext.class);
-
-
+  private static transient JavaSparkContext JSC;
+  private static HBaseTestingUtility HTU;
+  private static JavaHBaseContext HBASE_CONTEXT;
+  private static final Logger LOG = LoggerFactory.getLogger(TestJavaHBaseContext.class);
 
   byte[] tableName = Bytes.toBytes("t1");
   byte[] columnFamily = Bytes.toBytes("c");
@@ -82,56 +83,59 @@ public class TestJavaHBaseContext implements Serializable {
   String columnFamilyStr = Bytes.toString(columnFamily);
   String columnFamilyStr1 = Bytes.toString(columnFamily1);
 
-
-  @Before
-  public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaHBaseContextSuite");
-
+  @BeforeClass
+  public static void setUpBeforeClass() throws Exception {
     File tempDir = Files.createTempDir();
     tempDir.deleteOnExit();
 
-    htu = new HBaseTestingUtility();
-    try {
-      LOG.info("cleaning up test dir");
+    JSC = new JavaSparkContext("local", "JavaHBaseContextSuite");
+    HTU = new HBaseTestingUtility();
+    Configuration conf = HTU.getConfiguration();
 
-      htu.cleanupTestDir();
+    HBASE_CONTEXT = new JavaHBaseContext(JSC, conf);
 
-      LOG.info("starting minicluster");
+    LOG.info("cleaning up test dir");
 
-      htu.startMiniZKCluster();
-      htu.startMiniHBaseCluster(1, 1);
+    HTU.cleanupTestDir();
 
-      LOG.info(" - minicluster started");
+    LOG.info("starting minicluster");
 
-      try {
-        htu.deleteTable(TableName.valueOf(tableName));
-      } catch (Exception e) {
-        LOG.info(" - no table " + Bytes.toString(tableName) + " found");
-      }
+    HTU.startMiniZKCluster();
+    HTU.startMiniHBaseCluster(1, 1);
 
-      LOG.info(" - creating table " + Bytes.toString(tableName));
-      htu.createTable(TableName.valueOf(tableName),
-          new byte[][]{columnFamily, columnFamily1});
-      LOG.info(" - created table");
-    } catch (Exception e1) {
-      throw new RuntimeException(e1);
-    }
+    LOG.info(" - minicluster started");
   }
 
-  @After
-  public void tearDown() {
+  @AfterClass
+  public static void tearDownAfterClass() throws Exception {
+    LOG.info("shuting down minicluster");
+    HTU.shutdownMiniHBaseCluster();
+    HTU.shutdownMiniZKCluster();
+    LOG.info(" - minicluster shut down");
+    HTU.cleanupTestDir();
+
+    JSC.stop();
+    JSC = null;
+  }
+
+  @Before
+  public void setUp() throws Exception {
+
     try {
-      htu.deleteTable(TableName.valueOf(tableName));
-      LOG.info("shuting down minicluster");
-      htu.shutdownMiniHBaseCluster();
-      htu.shutdownMiniZKCluster();
-      LOG.info(" - minicluster shut down");
-      htu.cleanupTestDir();
+      HTU.deleteTable(TableName.valueOf(tableName));
     } catch (Exception e) {
-      throw new RuntimeException(e);
+      LOG.info(" - no table " + Bytes.toString(tableName) + " found");
     }
-    jsc.stop();
-    jsc = null;
+
+    LOG.info(" - creating table " + Bytes.toString(tableName));
+    HTU.createTable(TableName.valueOf(tableName),
+        new byte[][]{columnFamily, columnFamily1});
+    LOG.info(" - created table");
+  }
+
+  @After
+  public void tearDown() throws Exception {
+    HTU.deleteTable(TableName.valueOf(tableName));
   }
 
   @Test
@@ -144,11 +148,9 @@ public class TestJavaHBaseContext implements Serializable {
     list.add("4," + columnFamilyStr + ",a,4");
     list.add("5," + columnFamilyStr + ",a,5");
 
-    JavaRDD<String> rdd = jsc.parallelize(list);
+    JavaRDD<String> rdd = JSC.parallelize(list);
 
-    Configuration conf = htu.getConfiguration();
-
-    JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf);
+    Configuration conf = HTU.getConfiguration();
 
     Connection conn = ConnectionFactory.createConnection(conf);
     Table table = conn.getTable(TableName.valueOf(tableName));
@@ -163,7 +165,7 @@ public class TestJavaHBaseContext implements Serializable {
       table.close();
     }
 
-    hbaseContext.bulkPut(rdd,
+    HBASE_CONTEXT.bulkPut(rdd,
         TableName.valueOf(tableName),
         new PutFunction());
 
@@ -212,15 +214,13 @@ public class TestJavaHBaseContext implements Serializable {
     list.add(Bytes.toBytes("2"));
     list.add(Bytes.toBytes("3"));
 
-    JavaRDD<byte[]> rdd = jsc.parallelize(list);
+    JavaRDD<byte[]> rdd = JSC.parallelize(list);
 
-    Configuration conf = htu.getConfiguration();
+    Configuration conf = HTU.getConfiguration();
 
     populateTableWithMockData(conf, TableName.valueOf(tableName));
 
-    JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf);
-
-    hbaseContext.bulkDelete(rdd, TableName.valueOf(tableName),
+    HBASE_CONTEXT.bulkDelete(rdd, TableName.valueOf(tableName),
         new JavaHBaseBulkDeleteExample.DeleteFunction(), 2);
 
 
@@ -248,17 +248,15 @@ public class TestJavaHBaseContext implements Serializable {
 
   @Test
   public void testDistributedScan() throws IOException {
-    Configuration conf = htu.getConfiguration();
+    Configuration conf = HTU.getConfiguration();
 
     populateTableWithMockData(conf, TableName.valueOf(tableName));
 
-    JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf);
-
     Scan scan = new Scan();
     scan.setCaching(100);
 
     JavaRDD<String> javaRdd =
-        hbaseContext.hbaseRDD(TableName.valueOf(tableName), scan)
+        HBASE_CONTEXT.hbaseRDD(TableName.valueOf(tableName), scan)
             .map(new ScanConvertFunction());
 
     List<String> results = javaRdd.collect();
@@ -283,16 +281,14 @@ public class TestJavaHBaseContext implements Serializable {
     list.add(Bytes.toBytes("4"));
     list.add(Bytes.toBytes("5"));
 
-    JavaRDD<byte[]> rdd = jsc.parallelize(list);
+    JavaRDD<byte[]> rdd = JSC.parallelize(list);
 
-    Configuration conf = htu.getConfiguration();
+    Configuration conf = HTU.getConfiguration();
 
     populateTableWithMockData(conf, TableName.valueOf(tableName));
 
-    JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf);
-
     final JavaRDD<String> stringJavaRDD =
-        hbaseContext.bulkGet(TableName.valueOf(tableName), 2, rdd,
+        HBASE_CONTEXT.bulkGet(TableName.valueOf(tableName), 2, rdd,
             new GetFunction(),
             new ResultFunction());
 
@@ -302,7 +298,7 @@ public class TestJavaHBaseContext implements Serializable {
 
   @Test
   public void testBulkLoad() throws Exception {
-    Path output = htu.getDataTestDir("testBulkLoad");
+    Path output = HTU.getDataTestDir("testBulkLoad");
     // Add cell as String: "row,falmily,qualifier,value"
     List<String> list= new ArrayList<String>();
     // row1
@@ -315,14 +311,11 @@ public class TestJavaHBaseContext implements Serializable {
     list.add("2," + columnFamilyStr + ",a,3");
     list.add("2," + columnFamilyStr + ",b,3");
 
-    JavaRDD<String> rdd = jsc.parallelize(list);
-
-    Configuration conf = htu.getConfiguration();
-    JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf);
-
+    JavaRDD<String> rdd = JSC.parallelize(list);
+    Configuration conf = HTU.getConfiguration();
 
-    hbaseContext.bulkLoad(rdd, TableName.valueOf(tableName), new BulkLoadFunction(),
+    HBASE_CONTEXT.bulkLoad(rdd, TableName.valueOf(tableName), new BulkLoadFunction(),
         output.toUri().getPath(), new HashMap<byte[], FamilyHFileWriteOptions>(), false,
         HConstants.DEFAULT_MAX_FILE_SIZE);
 
 
@@ -369,7 +362,7 @@ public class TestJavaHBaseContext implements Serializable {
 
   @Test
   public void testBulkLoadThinRows() throws Exception {
-    Path output = htu.getDataTestDir("testBulkLoadThinRows");
+    Path output = HTU.getDataTestDir("testBulkLoadThinRows");
     // because of the limitation of scala bulkLoadThinRows API
     // we need to provide data as <row, all cells in that row>
     List<List<String>> list= new ArrayList<List<String>>();
@@ -389,12 +382,11 @@ public class TestJavaHBaseContext implements Serializable {
     list2.add("2," + columnFamilyStr + ",b,3");
     list.add(list2);
 
-    JavaRDD<List<String>> rdd = jsc.parallelize(list);
+    JavaRDD<List<String>> rdd = JSC.parallelize(list);
 
-    Configuration conf = htu.getConfiguration();
-    JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf);
+    Configuration conf = HTU.getConfiguration();
 
-    hbaseContext.bulkLoadThinRows(rdd, TableName.valueOf(tableName), new BulkLoadThinRowsFunction(),
+    HBASE_CONTEXT.bulkLoadThinRows(rdd, TableName.valueOf(tableName), new BulkLoadThinRowsFunction(),
         output.toString(), new HashMap<byte[], FamilyHFileWriteOptions>(), false,
         HConstants.DEFAULT_MAX_FILE_SIZE);
 
diff --git a/spark/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/HBaseContextSuite.scala b/spark/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/HBaseContextSuite.scala
index 83e2ac6..1b35b93 100644
--- a/spark/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/HBaseContextSuite.scala
+++ b/spark/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/HBaseContextSuite.scala
@@ -27,6 +27,7 @@ class HBaseContextSuite extends FunSuite with
 BeforeAndAfterEach with BeforeAndAfterAll with Logging {
 
   @transient var sc: SparkContext = null
+  var hbaseContext: HBaseContext = null
   var TEST_UTIL = new HBaseTestingUtility
 
   val tableName = "t1"
@@ -49,6 +50,9 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
     val envMap = Map[String,String](("Xmx", "512m"))
 
     sc = new SparkContext("local", "test", null, Nil, envMap)
+
+    val config = TEST_UTIL.getConfiguration
+    hbaseContext = new HBaseContext(sc, config)
   }
 
   override def afterAll() {
@@ -73,7 +77,6 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
       (Bytes.toBytes("5"),
         Array((Bytes.toBytes(columnFamily), Bytes.toBytes("e"), Bytes.toBytes("bar"))))))
 
-    val hbaseContext = new HBaseContext(sc, config)
     hbaseContext.bulkPut[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])](rdd,
       TableName.valueOf(tableName),
       (putRecord) => {
@@ -132,7 +135,6 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
       Bytes.toBytes("delete1"),
       Bytes.toBytes("delete3")))
 
-    val hbaseContext = new HBaseContext(sc, config)
     hbaseContext.bulkDelete[Array[Byte]](rdd,
       TableName.valueOf(tableName),
       putRecord => new Delete(putRecord),
@@ -174,7 +176,6 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
       Bytes.toBytes("get2"),
       Bytes.toBytes("get3"),
       Bytes.toBytes("get4")))
-    val hbaseContext = new HBaseContext(sc, config)
 
     val getRdd = hbaseContext.bulkGet[Array[Byte], String](
       TableName.valueOf(tableName),
@@ -221,7 +222,6 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
       Bytes.toBytes("get2"),
       Bytes.toBytes("get3"),
       Bytes.toBytes("get4")))
-    val hbaseContext = new HBaseContext(sc, config)
 
     intercept[SparkException] {
       try {
@@ -274,7 +274,6 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
       Bytes.toBytes("get2"),
       Bytes.toBytes("get3"),
       Bytes.toBytes("get4")))
-    val hbaseContext = new HBaseContext(sc, config)
 
     val getRdd = hbaseContext.bulkGet[Array[Byte], String](
       TableName.valueOf(tableName),
@@ -329,8 +328,6 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
       connection.close()
     }
 
-    val hbaseContext = new HBaseContext(sc, config)
-
     val scan = new Scan()
     val filter = new FirstKeyOnlyFilter()
     scan.setCaching(100)
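
Note (not part of the patch): the sketch below is a minimal, hypothetical Scala driver illustrating the pattern this change moves to, i.e. one HBaseContext built once on the driver and reused for every bulk operation, the same way the refactored TestJavaHBaseContext now shares a single static HBASE_CONTEXT. The object name SharedHBaseContextSketch, the local master URL, and the table and column-family names ("t1", "c") are illustrative only; credentials are obtained by TableMapReduceUtil.initCredentials(job) inside the HBaseContext constructor rather than through the removed broadcast.

import org.apache.hadoop.hbase.{HBaseConfiguration, TableName}
import org.apache.hadoop.hbase.client.Put
import org.apache.hadoop.hbase.spark.HBaseContext
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.SparkContext

object SharedHBaseContextSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "SharedHBaseContextSketch")
    // In a real job this configuration would carry the cluster's hbase-site.xml.
    val config = HBaseConfiguration.create()

    // One HBaseContext for the whole application; its constructor calls
    // TableMapReduceUtil.initCredentials(job), so executors rely on the current
    // user's UGI (marked as PROXY in applyCreds) instead of broadcast credentials.
    val hbaseContext = new HBaseContext(sc, config)

    val rdd = sc.parallelize(Seq(
      (Bytes.toBytes("1"), Array((Bytes.toBytes("c"), Bytes.toBytes("a"), Bytes.toBytes("foo")))),
      (Bytes.toBytes("2"), Array((Bytes.toBytes("c"), Bytes.toBytes("b"), Bytes.toBytes("bar"))))))

    // Reuse the same instance for every bulk operation, as the refactored tests do.
    hbaseContext.bulkPut[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])](
      rdd,
      TableName.valueOf("t1"),
      (putRecord) => {
        val put = new Put(putRecord._1)
        putRecord._2.foreach(cell => put.addColumn(cell._1, cell._2, cell._3))
        put
      })

    sc.stop()
  }
}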