diff --git hbase-client/src/main/java/org/apache/hadoop/hbase/client/AsyncProcess.java hbase-client/src/main/java/org/apache/hadoop/hbase/client/AsyncProcess.java index d699233..f63c947 100644 --- hbase-client/src/main/java/org/apache/hadoop/hbase/client/AsyncProcess.java +++ hbase-client/src/main/java/org/apache/hadoop/hbase/client/AsyncProcess.java @@ -1032,7 +1032,6 @@ class AsyncProcess { for (Map.Entry> e : actionsByServer.entrySet()) { ServerName server = e.getKey(); MultiAction multiAction = e.getValue(); - incTaskCounters(multiAction.getRegions(), server); Collection runnables = getNewMultiActionRunnable(server, multiAction, numAttempt); // make sure we correctly count the number of runnables before we try to reuse the send @@ -1080,6 +1079,7 @@ class AsyncProcess { if (connection.getConnectionMetrics() != null) { connection.getConnectionMetrics().incrNormalRunners(); } + incTaskCounters(multiAction.getRegions(), server); return Collections.singletonList(Trace.wrap("AsyncProcess.sendMultiAction", new SingleServerRequestRunnable(multiAction, numAttempt, server, callsInProgress))); } @@ -1101,6 +1101,7 @@ class AsyncProcess { List toReturn = new ArrayList(actions.size()); for (DelayingRunner runner : actions.values()) { + incTaskCounters(runner.getActions().getRegions(), server); String traceText = "AsyncProcess.sendMultiAction"; Runnable runnable = new SingleServerRequestRunnable(runner.getActions(), numAttempt, server, @@ -1723,7 +1724,8 @@ class AsyncProcess { } } - private void updateStats(ServerName server, Map results) { + @VisibleForTesting + protected void updateStats(ServerName server, Map results) { boolean metrics = AsyncProcess.this.connection.getConnectionMetrics() != null; boolean stats = AsyncProcess.this.connection.getStatisticsTracker() != null; if (!stats && !metrics) { diff --git hbase-client/src/test/java/org/apache/hadoop/hbase/client/TestAsyncProcess.java hbase-client/src/test/java/org/apache/hadoop/hbase/client/TestAsyncProcess.java index 5959078..d8e181d 100644 --- hbase-client/src/test/java/org/apache/hadoop/hbase/client/TestAsyncProcess.java +++ hbase-client/src/test/java/org/apache/hadoop/hbase/client/TestAsyncProcess.java @@ -57,6 +57,10 @@ import org.apache.hadoop.hbase.RegionLocations; import org.apache.hadoop.hbase.ServerName; import org.apache.hadoop.hbase.TableName; import org.apache.hadoop.hbase.client.AsyncProcess.AsyncRequestFuture; +import org.apache.hadoop.hbase.client.backoff.ClientBackoffPolicy; +import org.apache.hadoop.hbase.client.backoff.ClientBackoffPolicyFactory; +import org.apache.hadoop.hbase.client.backoff.ExponentialClientBackoffPolicy; +import org.apache.hadoop.hbase.client.backoff.ServerStatistics; import org.apache.hadoop.hbase.client.coprocessor.Batch; import org.apache.hadoop.hbase.client.coprocessor.Batch.Callback; import org.apache.hadoop.hbase.ipc.RpcControllerFactory; @@ -65,6 +69,7 @@ import org.apache.hadoop.hbase.testclassification.MediumTests; import org.apache.hadoop.hbase.util.Bytes; import org.apache.hadoop.hbase.util.Threads; import org.junit.Assert; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import org.junit.BeforeClass; import org.junit.Rule; @@ -188,6 +193,11 @@ public class TestAsyncProcess { return super.submit(DUMMY_TABLE, rows, atLeastOne, callback, true); } + + @Override + protected void updateStats(ServerName server, Map results) { + // Do nothing for avoiding the NPE if we test the ClientBackofPolicy. + } @Override protected RpcRetryingCaller createCaller( CancellableRegionServerCallable callable) { @@ -258,7 +268,21 @@ public class TestAsyncProcess { return new CallerWithFailure(ioe); } } - + /** + * Make the backoff time always different on each call. + */ + static class MyClientBackoffPolicy implements ClientBackoffPolicy { + private final Map count = new HashMap<>(); + @Override + public long getBackoffTime(ServerName serverName, byte[] region, ServerStatistics stats) { + AtomicInteger inc = count.get(serverName); + if (inc == null) { + inc = new AtomicInteger(0); + count.put(serverName, inc); + } + return inc.getAndIncrement(); + } + } class MyAsyncProcessWithReplicas extends MyAsyncProcess { private Set failures = new TreeSet(new Bytes.ByteArrayComparator()); private long primarySleepMs = 0, replicaSleepMs = 0; @@ -618,6 +642,46 @@ public class TestAsyncProcess { } @Test + public void testTaskCountWithoutClientBackoffPolicy() throws IOException, InterruptedException { + ClusterConnection hc = createHConnection(); + MyAsyncProcess ap = new MyAsyncProcess(hc, conf, false); + testTaskCount(ap); + } + + @Test + public void testTaskCountWithClientBackoffPolicy() throws IOException, InterruptedException { + Configuration copyConf = new Configuration(conf); + copyConf.setBoolean(HConstants.ENABLE_CLIENT_BACKPRESSURE, true); + MyClientBackoffPolicy bp = new MyClientBackoffPolicy(); + ClusterConnection hc = createHConnection(); + Mockito.when(hc.getConfiguration()).thenReturn(copyConf); + Mockito.when(hc.getStatisticsTracker()).thenReturn(ServerStatisticTracker.create(copyConf)); + Mockito.when(hc.getBackoffPolicy()).thenReturn(bp); + MyAsyncProcess ap = new MyAsyncProcess(hc, copyConf, false); + testTaskCount(ap); + } + + private void testTaskCount(AsyncProcess ap) throws InterruptedIOException, InterruptedException { + List puts = new ArrayList<>(); + for (int i = 0; i != 3; ++i) { + puts.add(createPut(1, true)); + puts.add(createPut(2, true)); + puts.add(createPut(3, true)); + } + ap.submit(DUMMY_TABLE, puts, true, null, false); + ap.waitUntilDone(); + // More time to wait if there are incorrect task count. + TimeUnit.SECONDS.sleep(1); + assertEquals(0, ap.tasksInProgress.get()); + for (AtomicInteger count : ap.taskCounterPerRegion.values()) { + assertEquals(0, count.get()); + } + for (AtomicInteger count : ap.taskCounterPerServer.values()) { + assertEquals(0, count.get()); + } + } + + @Test public void testMaxTask() throws Exception { final AsyncProcess ap = new MyAsyncProcess(createHConnection(), conf, false);