diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/regionserver/SecureBulkLoadManager.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/regionserver/SecureBulkLoadManager.java index a4ee517fd6..f1f8b939e5 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/regionserver/SecureBulkLoadManager.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/regionserver/SecureBulkLoadManager.java @@ -25,6 +25,8 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.BiFunction; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileStatus; @@ -50,6 +52,7 @@ import org.apache.hadoop.hbase.util.Pair; import org.apache.hadoop.io.Text; import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.token.Token; +import org.apache.hbase.thirdparty.com.google.common.annotations.VisibleForTesting; import org.apache.yetus.audience.InterfaceAudience; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -106,6 +109,7 @@ public class SecureBulkLoadManager { private Path baseStagingDir; private UserProvider userProvider; + ConcurrentHashMap ugiReferenceCounter; private Connection conn; SecureBulkLoadManager(Configuration conf, Connection conn) { @@ -116,6 +120,7 @@ public class SecureBulkLoadManager { public void start() throws IOException { random = new SecureRandom(); userProvider = UserProvider.instantiate(conf); + ugiReferenceCounter = new ConcurrentHashMap(); fs = FileSystem.get(conf); baseStagingDir = new Path(FSUtils.getRootDir(conf), HConstants.BULKLOAD_STAGING_DIR_NAME); @@ -158,7 +163,7 @@ public class SecureBulkLoadManager { } finally { UserGroupInformation ugi = getActiveUser().getUGI(); try { - if (!UserGroupInformation.getLoginUser().equals(ugi)) { + if (!UserGroupInformation.getLoginUser().equals(ugi) && !isUserReferenced(ugi) ) { FileSystem.closeAllForUGI(ugi); 
} } catch (IOException e) { @@ -167,6 +172,40 @@ public class SecureBulkLoadManager { } } + @VisibleForTesting + interface InternalObserverForTest { + void afterFileSystemCreated(HRegion r); + } + InternalObserverForTest testObserver = null; + + @VisibleForTesting + void setTestObserver(InternalObserverForTest testObserver) { + this.testObserver = testObserver; + } + + + void incrementUgiReference(UserGroupInformation ugi) { + ugiReferenceCounter.merge(ugi, 1, new BiFunction() { + @Override + public Integer apply(Integer integer, Integer integer2) { + return ++integer; + } + }); + } + + void decrementUgiReference(UserGroupInformation ugi) { + ugiReferenceCounter.computeIfPresent(ugi, new BiFunction() { + @Override + public Integer apply(UserGroupInformation userGroupInformation, Integer integer) { + return integer > 1 ? --integer : null; + } + }); + } + + boolean isUserReferenced(UserGroupInformation ugi) { + return ugiReferenceCounter.containsKey( ugi ) && ugiReferenceCounter.get(ugi) > 0; + } + public Map> secureBulkLoadHFiles(final HRegion region, final BulkLoadHFileRequest request) throws IOException { final List> familyPaths = new ArrayList<>(request.getFamilyPathCount()); @@ -208,6 +247,7 @@ public class SecureBulkLoadManager { Map> map = null; try { + incrementUgiReference(ugi); // Get the target fs (HBase region server fs) delegation token // Since we have checked the permission via 'preBulkLoadHFile', now let's give // the 'request user' necessary token to operate on the target fs. 
@@ -237,6 +277,10 @@ public class SecureBulkLoadManager { fs.setPermission(stageFamily, PERM_ALL_ACCESS); } } + /// just for test + if( testObserver != null ) { + testObserver.afterFileSystemCreated(region); + } //We call bulkLoadHFiles as requesting user //To enable access prior to staging return region.bulkLoadHFiles(familyPaths, true, @@ -251,6 +295,7 @@ public class SecureBulkLoadManager { if (region.getCoprocessorHost() != null) { region.getCoprocessorHost().postBulkLoadHFile(familyPaths, map); } + decrementUgiReference(ugi); } return map; } diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/regionserver/TestSecureBulkloadManager.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/regionserver/TestSecureBulkloadManager.java new file mode 100644 index 0000000000..327820d4f3 --- /dev/null +++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/regionserver/TestSecureBulkloadManager.java @@ -0,0 +1,189 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package org.apache.hadoop.hbase.regionserver;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Deque;
import java.util.Map;
import java.util.concurrent.ExecutorService;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hbase.Cell;
import org.apache.hadoop.hbase.DoNotRetryIOException;
import org.apache.hadoop.hbase.HBaseClassTestRule;
import org.apache.hadoop.hbase.HBaseTestingUtility;
import org.apache.hadoop.hbase.HColumnDescriptor;
import org.apache.hadoop.hbase.HTableDescriptor;
import org.apache.hadoop.hbase.TableName;
import org.apache.hadoop.hbase.client.Connection;
import org.apache.hadoop.hbase.client.Get;
import org.apache.hadoop.hbase.client.Put;
import org.apache.hadoop.hbase.client.Result;
import org.apache.hadoop.hbase.client.Table;
import org.apache.hadoop.hbase.io.compress.Compression;
import org.apache.hadoop.hbase.io.crypto.Encryption;
import org.apache.hadoop.hbase.io.hfile.CacheConfig;
import org.apache.hadoop.hbase.io.hfile.HFile;
import org.apache.hadoop.hbase.io.hfile.HFileContext;
import org.apache.hadoop.hbase.io.hfile.HFileContextBuilder;
import org.apache.hadoop.hbase.regionserver.SecureBulkLoadManager.InternalObserverForTest;
import org.apache.hadoop.hbase.testclassification.MediumTests;
import org.apache.hadoop.hbase.testclassification.RegionServerTests;
import org.apache.hadoop.hbase.tool.LoadIncrementalHFiles;
import org.apache.hadoop.hbase.tool.LoadIncrementalHFiles.LoadQueueItem;
import org.apache.hadoop.hbase.util.Bytes;
import org.apache.hadoop.hbase.util.EnvironmentEdgeManager;
import org.apache.hbase.thirdparty.com.google.common.collect.Multimap;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.experimental.categories.Category;

/**
 * Verifies that two concurrent secure bulk loads running as the same user do
 * not race: the cleanup of the first load must not close the cached
 * FileSystem while the second load is still using it.
 *
 * <p>MediumTests, not SmallTests: this test starts a mini HBase cluster,
 * which cannot run inside the small-test timeout enforced by
 * {@link HBaseClassTestRule}.
 */
@Category({RegionServerTests.class, MediumTests.class})
public class TestSecureBulkloadManager implements InternalObserverForTest {

  @ClassRule
  public static final HBaseClassTestRule CLASS_RULE =
      HBaseClassTestRule.forClass(TestSecureBulkloadManager.class);

  private static final TableName TABLE = TableName.valueOf("testcompact");
  private static final byte[] FAMILY = Bytes.toBytes("family");
  private static final byte[] COLUMN = Bytes.toBytes("column");
  private static final byte[] key1 = Bytes.toBytes("row1");
  private static final byte[] key2 = Bytes.toBytes("row2");
  private static final byte[] key3 = Bytes.toBytes("row3");
  private static final byte[] value1 = Bytes.toBytes("t1");
  private static final byte[] value3 = Bytes.toBytes("t3");
  // Split point: key1 and key3 land in different regions, so the two bulk
  // loads are served by two concurrent secureBulkLoadHFiles() calls.
  private static final byte[] SPLIT_ROWKEY = key2;

  protected static final HBaseTestingUtility testUtil = new HBaseTestingUtility();
  private static Configuration conf = testUtil.getConfiguration();

  @BeforeClass
  public static void setUp() throws Exception {
    testUtil.startMiniCluster();
  }

  @AfterClass
  public static void tearDown() throws Exception {
    testUtil.shutdownMiniCluster();
    testUtil.cleanupTestDir();
  }

  @Test
  public void testForRaceCondition() throws Exception {
    // Install the observer so the region holding key3 sleeps right after the
    // request user's FileSystem is created, widening the race window.
    testUtil.getMiniHBaseCluster().getRegionServerThreads().get(0)
        .getRegionServer().secureBulkLoadManager.setTestObserver(this);

    // Create a table pre-split at SPLIT_ROWKEY so the two loads hit
    // different regions of the same region server.
    HTableDescriptor htd = new HTableDescriptor(TABLE);
    htd.addFamily(new HColumnDescriptor(FAMILY));
    testUtil.getAdmin().createTable(htd, Bytes.toByteArrays(SPLIT_ROWKEY));

    // Prepare one HFile per region.
    Path rootdir = testUtil.getMiniHBaseCluster().getRegionServerThreads().get(0)
        .getRegionServer().getRootDir();
    final Path dir1 = new Path(rootdir, "dir1");
    prepareHFile(dir1, key1, value1);
    final Path dir2 = new Path(rootdir, "dir2");
    prepareHFile(dir2, key3, value3);

    // Run both bulk loads concurrently. Without the UGI reference counting in
    // SecureBulkLoadManager, the first load's cleanup could close the shared
    // FileSystem while the second (delayed) load was still using it.
    Thread t1 = new Thread(() -> doBulkloadWithoutRetry(dir1));
    Thread t2 = new Thread(() -> doBulkloadWithoutRetry(dir2));
    t1.start();
    t2.start();
    t1.join();
    t2.join();

    // Both rows must be visible, i.e. both loads succeeded.
    try (Table t = testUtil.getConnection().getTable(TABLE)) {
      Result r = t.get(new Get(key1));
      Assert.assertArrayEquals(value1, r.getValue(FAMILY, COLUMN));
      r = t.get(new Get(key3));
      Assert.assertArrayEquals(value3, r.getValue(FAMILY, COLUMN));
    }
  }

  /** Thrown after the first bulkLoadPhase() to stop LoadIncrementalHFiles from retrying. */
  private static class MyExceptionToAvoidRetry extends DoNotRetryIOException {
  }

  /**
   * Performs a single-attempt bulk load of {@code dir}. Any real failure is
   * rethrown (wrapped) so it is not silently swallowed inside the worker
   * thread — the caller's assertions would otherwise report a misleading
   * missing-row failure.
   */
  void doBulkloadWithoutRetry(Path dir) {
    try {
      Connection connection = testUtil.getConnection();
      LoadIncrementalHFiles h = new LoadIncrementalHFiles(conf) {
        @Override
        protected void bulkLoadPhase(final Table htable, final Connection conn,
            ExecutorService pool, Deque<LoadQueueItem> queue,
            final Multimap<ByteBuffer, LoadQueueItem> regionGroups, boolean copyFile,
            Map<LoadQueueItem, ByteBuffer> item2RegionMap) throws IOException {
          super.bulkLoadPhase(htable, conn, pool, queue, regionGroups, copyFile, item2RegionMap);
          // Abort after the first (successful) pass so the loader cannot retry.
          throw new MyExceptionToAvoidRetry();
        }
      };
      try {
        h.doBulkLoad(dir, testUtil.getAdmin(), connection.getTable(TABLE),
            connection.getRegionLocator(TABLE));
        Assert.fail("MyExceptionToAvoidRetry is expected");
      } catch (MyExceptionToAvoidRetry ignored) {
        // Expected: thrown above to short-circuit the retry loop.
      }
    } catch (Exception e) {
      throw new RuntimeException("bulk load failed for " + dir, e);
    }
  }

  /**
   * Writes a single-cell HFile for {@code TABLE}'s family under
   * {@code dir/<family>} so it can be picked up by LoadIncrementalHFiles.
   */
  void prepareHFile(Path dir, byte[] key, byte[] value) throws Exception {
    HTableDescriptor desc = testUtil.getAdmin().getTableDescriptor(TABLE);
    HColumnDescriptor family = desc.getFamily(FAMILY);
    Compression.Algorithm compression = HFile.DEFAULT_COMPRESSION_ALGORITHM;

    CacheConfig writerCacheConf = new CacheConfig(conf, family);
    writerCacheConf.setCacheDataOnWrite(false);

    HFileContext hFileContext = new HFileContextBuilder()
        .withIncludesMvcc(false)
        .withIncludesTags(true)
        .withCompression(compression)
        .withCompressTags(family.isCompressTags())
        .withChecksumType(HStore.getChecksumType(conf))
        .withBytesPerCheckSum(HStore.getBytesPerChecksum(conf))
        .withBlockSize(family.getBlocksize())
        .withHBaseCheckSum(true)
        .withDataBlockEncoding(family.getDataBlockEncoding())
        .withEncryptionContext(Encryption.Context.NONE)
        .withCreateTime(EnvironmentEdgeManager.currentTime())
        .build();

    StoreFileWriter writer =
        new StoreFileWriter.Builder(conf, writerCacheConf, dir.getFileSystem(conf))
            .withOutputDir(new Path(dir, family.getNameAsString()))
            .withBloomType(family.getBloomFilterType())
            .withMaxKeyCount(Integer.MAX_VALUE)
            .withFileContext(hFileContext)
            .build();
    // close() in finally so the writer is not leaked if append() throws.
    try {
      Put put = new Put(key);
      put.addColumn(FAMILY, COLUMN, value);
      for (Cell c : put.get(FAMILY, COLUMN)) {
        writer.append(c);
      }
    } finally {
      writer.close();
    }
  }

  @Override
  public void afterFileSystemCreated(HRegion r) {
    // Only delay the load targeting key3's region, giving the other load time
    // to finish and run its cleanup while this one is still in flight.
    if (r.getRegionInfo().containsRow(key3)) {
      try {
        Thread.sleep(3000); // 3s: enough for the sibling load to complete cleanup
      } catch (InterruptedException e) {
        // Restore the interrupt status instead of swallowing it.
        Thread.currentThread().interrupt();
      }
    }
  }
}