From 6202a7a0b3793328d2e2eaf091d7fcb926f70891 Mon Sep 17 00:00:00 2001 From: Jurriaan Mous Date: Tue, 6 Jan 2015 13:15:35 +0100 Subject: [PATCH] HBASE-12684 Add new AsyncRpcClient Signed-off-by: stack --- .../apache/hadoop/hbase/ipc/AbstractRpcClient.java | 67 +- .../org/apache/hadoop/hbase/ipc/AsyncCall.java | 162 +++++ .../apache/hadoop/hbase/ipc/AsyncRpcChannel.java | 768 +++++++++++++++++++++ .../apache/hadoop/hbase/ipc/AsyncRpcClient.java | 385 +++++++++++ .../hbase/ipc/AsyncServerResponseHandler.java | 130 ++++ .../org/apache/hadoop/hbase/ipc/ConnectionId.java | 15 +- .../hbase/ipc/PayloadCarryingRpcController.java | 2 +- .../org/apache/hadoop/hbase/ipc/RpcClient.java | 10 +- .../apache/hadoop/hbase/ipc/RpcClientFactory.java | 3 +- .../org/apache/hadoop/hbase/ipc/RpcClientImpl.java | 38 +- .../hadoop/hbase/ipc/TimeLimitedRpcController.java | 8 +- .../hadoop/hbase/security/HBaseSaslRpcClient.java | 2 +- .../hadoop/hbase/security/SaslClientHandler.java | 353 ++++++++++ .../java/org/apache/hadoop/hbase/HConstants.java | 2 +- .../org/apache/hadoop/hbase/ipc/RpcServer.java | 5 +- .../java/org/apache/hadoop/hbase/ipc/TestIPC.java | 295 +++++++- .../hadoop/hbase/ipc/TestRpcHandlerException.java | 11 +- 17 files changed, 2158 insertions(+), 98 deletions(-) create mode 100644 hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncCall.java create mode 100644 hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncRpcChannel.java create mode 100644 hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncRpcClient.java create mode 100644 hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncServerResponseHandler.java create mode 100644 hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslClientHandler.java diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AbstractRpcClient.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AbstractRpcClient.java index df43f6f..c3d2624 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AbstractRpcClient.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AbstractRpcClient.java @@ -187,40 +187,32 @@ public abstract class AbstractRpcClient implements RpcClient { return config.getInt(HConstants.HBASE_CLIENT_IPC_POOL_SIZE, 1); } - /** * Make a blocking call. Throws exceptions if there are network problems or if the remote code * threw an exception. + * * @param ticket Be careful which ticket you pass. A new user will mean a new Connection. - * {@link UserProvider#getCurrent()} makes a new instance of User each time so will be a - * new Connection each time. + * {@link UserProvider#getCurrent()} makes a new instance of User each time so + * will be a + * new Connection each time. * @return A pair with the Message response and the Cell data (if any). */ Message callBlockingMethod(Descriptors.MethodDescriptor md, PayloadCarryingRpcController pcrc, Message param, Message returnType, final User ticket, final InetSocketAddress isa) throws ServiceException { + if (pcrc == null) { + pcrc = new PayloadCarryingRpcController(); + } + long startTime = 0; if (LOG.isTraceEnabled()) { startTime = EnvironmentEdgeManager.currentTime(); } - int callTimeout = 0; - CellScanner cells = null; - if (pcrc != null) { - callTimeout = pcrc.getCallTimeout(); - cells = pcrc.cellScanner(); - // Clear it here so we don't by mistake try and these cells processing results. - pcrc.setCellScanner(null); - } Pair val; try { - val = call(pcrc, md, param, cells, returnType, ticket, isa, callTimeout, - pcrc != null? pcrc.getPriority(): HConstants.NORMAL_QOS); - if (pcrc != null) { - // Shove the results into controller so can be carried across the proxy/pb service void. - if (val.getSecond() != null) pcrc.setCellScanner(val.getSecond()); - } else if (val.getSecond() != null) { - throw new ServiceException("Client dropping data on the floor!"); - } + val = call(pcrc, md, param, returnType, ticket, isa); + // Shove the results into controller so can be carried across the proxy/pb service void. + pcrc.setCellScanner(val.getSecond()); if (LOG.isTraceEnabled()) { long callTime = EnvironmentEdgeManager.currentTime() - startTime; @@ -238,26 +230,22 @@ public abstract class AbstractRpcClient implements RpcClient { * with the ticket credentials, returning the value. * Throws exceptions if there are network problems or if the remote code * threw an exception. + * * @param ticket Be careful which ticket you pass. A new user will mean a new Connection. - * {@link UserProvider#getCurrent()} makes a new instance of User each time so will be a - * new Connection each time. + * {@link UserProvider#getCurrent()} makes a new instance of User each time so + * will be a + * new Connection each time. * @return A pair with the Message response and the Cell data (if any). * @throws InterruptedException * @throws java.io.IOException */ protected abstract Pair call(PayloadCarryingRpcController pcrc, - Descriptors.MethodDescriptor md, Message param, CellScanner cells, - Message returnType, User ticket, InetSocketAddress addr, int callTimeout, int priority) throws - IOException, InterruptedException; + Descriptors.MethodDescriptor md, Message param, Message returnType, User ticket, + InetSocketAddress isa) throws IOException, InterruptedException; - /** - * Creates a "channel" that can be used by a blocking protobuf service. Useful setting up - * protobuf blocking stubs. - * @return A blocking rpc channel that goes via this rpc client instance. - */ @Override - public BlockingRpcChannel createBlockingRpcChannel(final ServerName sn, - final User ticket, int defaultOperationTimeout) { + public BlockingRpcChannel createBlockingRpcChannel(final ServerName sn, final User ticket, + int defaultOperationTimeout) { return new BlockingRpcChannelImplementation(this, sn, ticket, defaultOperationTimeout); } @@ -269,18 +257,17 @@ public abstract class AbstractRpcClient implements RpcClient { private final InetSocketAddress isa; private final AbstractRpcClient rpcClient; private final User ticket; - private final int defaultOperationTimeout; + private final int channelOperationTimeout; /** - * @param defaultOperationTimeout - the default timeout when no timeout is given - * by the caller. + * @param channelOperationTimeout - the default timeout when no timeout is given */ protected BlockingRpcChannelImplementation(final AbstractRpcClient rpcClient, - final ServerName sn, final User ticket, int defaultOperationTimeout) { + final ServerName sn, final User ticket, int channelOperationTimeout) { this.isa = new InetSocketAddress(sn.getHostname(), sn.getPort()); this.rpcClient = rpcClient; this.ticket = ticket; - this.defaultOperationTimeout = defaultOperationTimeout; + this.channelOperationTimeout = channelOperationTimeout; } @Override @@ -289,12 +276,12 @@ public abstract class AbstractRpcClient implements RpcClient { PayloadCarryingRpcController pcrc; if (controller != null) { pcrc = (PayloadCarryingRpcController) controller; - if (!pcrc.hasCallTimeout()){ - pcrc.setCallTimeout(defaultOperationTimeout); + if (!pcrc.hasCallTimeout()) { + pcrc.setCallTimeout(channelOperationTimeout); } } else { - pcrc = new PayloadCarryingRpcController(); - pcrc.setCallTimeout(defaultOperationTimeout); + pcrc = new PayloadCarryingRpcController(); + pcrc.setCallTimeout(channelOperationTimeout); } return this.rpcClient.callBlockingMethod(md, pcrc, param, returnType, this.ticket, this.isa); diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncCall.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncCall.java new file mode 100644 index 0000000..7dfb00b --- /dev/null +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncCall.java @@ -0,0 +1,162 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.ipc; + +import com.google.protobuf.Descriptors; +import com.google.protobuf.Message; +import io.netty.channel.EventLoop; +import io.netty.util.concurrent.DefaultPromise; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.hbase.CellScanner; +import org.apache.hadoop.hbase.classification.InterfaceAudience; +import org.apache.hadoop.hbase.protobuf.ProtobufUtil; +import org.apache.hadoop.hbase.util.EnvironmentEdgeManager; +import org.apache.hadoop.hbase.util.ExceptionUtil; +import org.apache.hadoop.ipc.RemoteException; + +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +/** + * Represents an Async Hbase call and its response. + * + * Responses are passed on to its given doneHandler and failures to the rpcController + */ +@InterfaceAudience.Private +public class AsyncCall extends DefaultPromise { + public static final Log LOG = LogFactory.getLog(AsyncCall.class.getName()); + + final int id; + + final Descriptors.MethodDescriptor method; + final Message param; + final PayloadCarryingRpcController controller; + final Message responseDefaultType; + final long startTime; + final long rpcTimeout; + + /** + * Constructor + * + * @param eventLoop for call + * @param connectId connection id + * @param md the method descriptor + * @param param parameters to send to Server + * @param controller controller for response + * @param responseDefaultType the default response type + */ + public AsyncCall(EventLoop eventLoop, int connectId, Descriptors.MethodDescriptor md, Message + param, PayloadCarryingRpcController controller, Message responseDefaultType) { + super(eventLoop); + + this.id = connectId; + + this.method = md; + this.param = param; + this.controller = controller; + this.responseDefaultType = responseDefaultType; + + this.startTime = EnvironmentEdgeManager.currentTime(); + this.rpcTimeout = controller.getCallTimeout(); + } + + /** + * Get the start time + * + * @return start time for the call + */ + public long getStartTime() { + return this.startTime; + } + + @Override public String toString() { + return "callId: " + this.id + " methodName: " + this.method.getName() + " param {" + + (this.param != null ? ProtobufUtil.getShortTextFormat(this.param) : "") + "}"; + } + + /** + * Set success with a cellBlockScanner + * + * @param value to set + * @param cellBlockScanner to set + */ + public void setSuccess(Message value, CellScanner cellBlockScanner) { + if (cellBlockScanner != null) { + controller.setCellScanner(cellBlockScanner); + } + + if (LOG.isTraceEnabled()) { + long callTime = EnvironmentEdgeManager.currentTime() - startTime; + LOG.trace("Call: " + method.getName() + ", callTime: " + callTime + "ms"); + } + + this.setSuccess(value); + } + + /** + * Set failed + * + * @param exception to set + */ + public void setFailed(IOException exception) { + if (ExceptionUtil.isInterrupt(exception)) { + exception = ExceptionUtil.asInterrupt(exception); + } + if (exception instanceof RemoteException) { + exception = ((RemoteException) exception).unwrapRemoteException(); + } + + this.setFailure(exception); + } + + /** + * Get the rpc timeout + * + * @return current timeout for this call + */ + public long getRpcTimeout() { + return rpcTimeout; + } + + @Override + public Message get() throws InterruptedException, ExecutionException { + try { + return super.get(Math.min(this.remainingTime(), 1000) + 1, TimeUnit.MILLISECONDS); + } catch (TimeoutException e) { + this.cancel(true); + throw new ExecutionException(e); + } + } + + /** + * Get the remaining time of the call + * + * @return remaining time in milliseconds + */ + public long remainingTime() { + if (rpcTimeout == 0) { + return Integer.MAX_VALUE; + } + + long remaining = rpcTimeout - (int) (EnvironmentEdgeManager.currentTime() - getStartTime()); + return remaining > 0 ? remaining : 0; + } +} \ No newline at end of file diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncRpcChannel.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncRpcChannel.java new file mode 100644 index 0000000..0f15b0b --- /dev/null +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncRpcChannel.java @@ -0,0 +1,768 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.ipc; + +import com.google.protobuf.Descriptors; +import com.google.protobuf.Message; +import com.google.protobuf.RpcCallback; +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufOutputStream; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import io.netty.util.Timeout; +import io.netty.util.TimerTask; +import io.netty.util.concurrent.GenericFutureListener; +import io.netty.util.concurrent.Promise; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.hbase.HConstants; +import org.apache.hadoop.hbase.classification.InterfaceAudience; +import org.apache.hadoop.hbase.exceptions.ConnectionClosingException; +import org.apache.hadoop.hbase.protobuf.generated.AuthenticationProtos; +import org.apache.hadoop.hbase.protobuf.generated.RPCProtos; +import org.apache.hadoop.hbase.protobuf.generated.TracingProtos; +import org.apache.hadoop.hbase.security.AuthMethod; +import org.apache.hadoop.hbase.security.SaslClientHandler; +import org.apache.hadoop.hbase.security.SaslUtil; +import org.apache.hadoop.hbase.security.SecurityInfo; +import org.apache.hadoop.hbase.security.User; +import org.apache.hadoop.hbase.security.token.AuthenticationTokenSelector; +import org.apache.hadoop.hbase.util.EnvironmentEdgeManager; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.ipc.RemoteException; +import org.apache.hadoop.security.SecurityUtil; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.security.token.Token; +import org.apache.hadoop.security.token.TokenIdentifier; +import org.apache.hadoop.security.token.TokenSelector; +import org.htrace.Span; +import org.htrace.Trace; + +import javax.security.sasl.SaslException; +import java.io.IOException; +import java.net.ConnectException; +import java.net.InetSocketAddress; +import java.net.SocketException; +import java.nio.ByteBuffer; +import java.security.PrivilegedExceptionAction; +import java.util.HashMap; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.concurrent.TimeUnit; + +/** + * Netty RPC channel + */ +@InterfaceAudience.Private +public class AsyncRpcChannel { + public static final Log LOG = LogFactory.getLog(AsyncRpcChannel.class.getName()); + + private static final int MAX_SASL_RETRIES = 5; + + protected final static Map> tokenHandlers = new HashMap<>(); + + static { + tokenHandlers.put(AuthenticationProtos.TokenIdentifier.Kind.HBASE_AUTH_TOKEN, + new AuthenticationTokenSelector()); + } + + final AsyncRpcClient client; + + // Contains the channel to work with. + // Only exists when connected + private Channel channel; + + String name; + final User ticket; + final String serviceName; + final InetSocketAddress address; + + ConcurrentSkipListMap calls = new ConcurrentSkipListMap<>(); + + private int ioFailureCounter = 0; + private int connectFailureCounter = 0; + + boolean useSasl; + AuthMethod authMethod; + private int reloginMaxBackoff; + private Token token; + private String serverPrincipal; + + boolean shouldCloseConnection = false; + private IOException closeException; + + private Timeout cleanupTimer; + + private final TimerTask timeoutTask = new TimerTask() { + @Override public void run(Timeout timeout) throws Exception { + cleanupTimer = null; + cleanupCalls(false); + } + }; + + /** + * Constructor for netty RPC channel + * + * @param bootstrap to construct channel on + * @param client to connect with + * @param ticket of user which uses connection + * @param serviceName name of service to connect to + * @param address to connect to + */ + public AsyncRpcChannel(Bootstrap bootstrap, final AsyncRpcClient client, User ticket, String + serviceName, InetSocketAddress address) { + this.client = client; + + this.ticket = ticket; + this.serviceName = serviceName; + this.address = address; + + this.channel = connect(bootstrap).channel(); + + name = ("IPC Client (" + channel.hashCode() + ") connection to " + + address.toString() + + ((ticket == null) ? + " from an unknown user" : + (" from " + ticket.getName()))); + } + + /** + * Connect to channel + * + * @param bootstrap to connect to + * @return future of connection + */ + private ChannelFuture connect(final Bootstrap bootstrap) { + return bootstrap.remoteAddress(address).connect() + .addListener(new GenericFutureListener() { + @Override + public void operationComplete(final ChannelFuture f) throws Exception { + if (!f.isSuccess()) { + if (f.cause() instanceof SocketException) { + retryOrClose(bootstrap, connectFailureCounter++, f.cause()); + } else { + retryOrClose(bootstrap, ioFailureCounter++, f.cause()); + } + return; + } + channel = f.channel(); + + setupAuthorization(); + + ByteBuf b = channel.alloc().heapBuffer(6); + createPreamble(b, authMethod); + channel.write(b).addListener(ChannelFutureListener.CLOSE_ON_FAILURE); + + if (useSasl) { + UserGroupInformation ticket = AsyncRpcChannel.this.ticket.getUGI(); + if (authMethod == AuthMethod.KERBEROS) { + if (ticket != null && ticket.getRealUser() != null) { + ticket = ticket.getRealUser(); + } + } + SaslClientHandler saslHandler; + if (ticket == null) { + throw new FatalConnectionException("ticket/user is null"); + } + saslHandler = ticket.doAs(new PrivilegedExceptionAction() { + @Override + public SaslClientHandler run() throws IOException { + return getSaslHandler(bootstrap); + } + }); + if (saslHandler != null) { + // Sasl connect is successful. Let's set up Sasl channel handler + channel.pipeline().addFirst(saslHandler); + } else { + // fall back to simple auth because server told us so. + authMethod = AuthMethod.SIMPLE; + useSasl = false; + } + } else { + startHBaseConnection(f.channel()); + } + } + }); + } + + /** + * Start HBase connection + * + * @param ch channel to start connection on + */ + private void startHBaseConnection(Channel ch) { + ch.pipeline() + .addLast("frameDecoder", new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)); + ch.pipeline().addLast(new AsyncServerResponseHandler(this)); + + try { + writeChannelHeader(ch).addListener(new GenericFutureListener() { + @Override public void operationComplete(ChannelFuture future) throws Exception { + if (!future.isSuccess()) { + close(future.cause()); + return; + } + for (AsyncCall call : calls.values()) { + writeRequest(call); + } + } + }); + } catch (IOException e) { + close(e); + } + } + + /** + * Get SASL handler + * + * @param bootstrap to reconnect to + * @return new SASL handler + * @throws java.io.IOException if handler failed to create + */ + private SaslClientHandler getSaslHandler(final Bootstrap bootstrap) throws IOException { + return new SaslClientHandler(authMethod, token, serverPrincipal, client.fallbackAllowed, + client.conf.get("hbase.rpc.protection", + SaslUtil.QualityOfProtection.AUTHENTICATION.name().toLowerCase()), + new SaslClientHandler.SaslExceptionHandler() { + @Override public void handle(int retryCount, Random random, Throwable cause) { + try { + // Handle Sasl failure. Try to potentially get new credentials + handleSaslConnectionFailure(retryCount, cause, ticket.getUGI()); + + // Try to reconnect + AsyncRpcClient.WHEEL_TIMER.newTimeout(new TimerTask() { + @Override public void run(Timeout timeout) throws Exception { + connect(bootstrap); + } + }, random.nextInt(reloginMaxBackoff) + 1, TimeUnit.MILLISECONDS); + } catch (IOException | InterruptedException e) { + close(e); + } + } + }, new SaslClientHandler.SaslSuccessfulConnectHandler() { + @Override public void onSuccess(Channel channel) { + startHBaseConnection(channel); + } + }); + } + + /** + * Retry to connect or close + * + * @param bootstrap to connect with + * @param connectCounter amount of tries + * @param e exception of fail + */ + private void retryOrClose(final Bootstrap bootstrap, int connectCounter, Throwable e) { + if (connectCounter < client.maxRetries) { + AsyncRpcClient.WHEEL_TIMER.newTimeout(new TimerTask() { + @Override public void run(Timeout timeout) throws Exception { + connect(bootstrap); + } + }, client.failureSleep, TimeUnit.MILLISECONDS); + } else { + client.failedServers.addToFailedServers(address); + close(e); + } + } + + /** + * Calls method on channel + * @param method to call + * @param controller to run call with + * @param request to send + * @param responsePrototype to construct response with + */ + public Promise callMethod(final Descriptors.MethodDescriptor method, + final PayloadCarryingRpcController controller, final Message request, + final Message responsePrototype) { + if (shouldCloseConnection) { + Promise promise = channel.eventLoop().newPromise(); + promise.setFailure(new ConnectException()); + return promise; + } + + final AsyncCall call = new AsyncCall(channel.eventLoop(), client.callIdCnt.getAndIncrement(), + method, request, controller, responsePrototype); + + controller.notifyOnCancel(new RpcCallback() { + @Override + public void run(Object parameter) { + failCall(call, new IOException("Canceled connection")); + } + }); + + calls.put(call.id, call); + + // Add timeout for cleanup if none is present + if (cleanupTimer == null) { + cleanupTimer = AsyncRpcClient.WHEEL_TIMER.newTimeout(timeoutTask, call.getRpcTimeout(), + TimeUnit.MILLISECONDS); + } + + if(channel.isActive()) { + writeRequest(call); + } + + return call; + } + + /** + * Calls method and returns a promise + * @param method to call + * @param controller to run call with + * @param request to send + * @param responsePrototype for response message + * @return Promise to listen to result + * @throws java.net.ConnectException on connection failures + */ + public Promise callMethodWithPromise( + final Descriptors.MethodDescriptor method, final PayloadCarryingRpcController controller, + final Message request, final Message responsePrototype) throws ConnectException { + if (shouldCloseConnection || !channel.isOpen()) { + throw new ConnectException(); + } + + return this.callMethod(method, controller, request, responsePrototype); + } + + /** + * Write the channel header + * + * @param channel to write to + * @return future of write + * @throws java.io.IOException on failure to write + */ + private ChannelFuture writeChannelHeader(Channel channel) throws IOException { + RPCProtos.ConnectionHeader.Builder headerBuilder = + RPCProtos.ConnectionHeader.newBuilder().setServiceName(serviceName); + + RPCProtos.ConnectionHeader.Builder builder = RPCProtos.ConnectionHeader.newBuilder(); + builder.setServiceName(serviceName); + RPCProtos.UserInformation userInfoPB = buildUserInfo(ticket.getUGI(), authMethod); + if (userInfoPB != null) { + headerBuilder.setUserInfo(userInfoPB); + } + + if (client.codec != null) { + headerBuilder.setCellBlockCodecClass(client.codec.getClass().getCanonicalName()); + } + if (client.compressor != null) { + headerBuilder.setCellBlockCompressorClass(client.compressor.getClass().getCanonicalName()); + } + + RPCProtos.ConnectionHeader header = headerBuilder.build(); + + + int totalSize = IPCUtil.getTotalSizeWhenWrittenDelimited(header); + + ByteBuf b = channel.alloc().heapBuffer(totalSize); + + b.writeInt(header.getSerializedSize()); + b.writeBytes(header.toByteArray()); + + return channel.writeAndFlush(b); + } + + /** + * Write request to channel + * + * @param call to write + */ + private void writeRequest(final AsyncCall call) { + try { + if (shouldCloseConnection) { + return; + } + + final RPCProtos.RequestHeader.Builder requestHeaderBuilder = RPCProtos.RequestHeader + .newBuilder(); + requestHeaderBuilder.setCallId(call.id) + .setMethodName(call.method.getName()).setRequestParam(call.param != null); + + if (Trace.isTracing()) { + Span s = Trace.currentSpan(); + requestHeaderBuilder.setTraceInfo(TracingProtos.RPCTInfo.newBuilder(). + setParentId(s.getSpanId()).setTraceId(s.getTraceId())); + } + + ByteBuffer cellBlock = client.buildCellBlock(call.controller.cellScanner()); + if (cellBlock != null) { + final RPCProtos.CellBlockMeta.Builder cellBlockBuilder = RPCProtos.CellBlockMeta + .newBuilder(); + cellBlockBuilder.setLength(cellBlock.limit()); + requestHeaderBuilder.setCellBlockMeta(cellBlockBuilder.build()); + } + // Only pass priority if there one. Let zero be same as no priority. + if (call.controller.getPriority() != 0) { + requestHeaderBuilder.setPriority(call.controller.getPriority()); + } + + RPCProtos.RequestHeader rh = requestHeaderBuilder.build(); + + int totalSize = IPCUtil.getTotalSizeWhenWrittenDelimited(rh, call.param); + if (cellBlock != null) { + totalSize += cellBlock.remaining(); + } + + ByteBuf b = channel.alloc().heapBuffer(totalSize); + try(ByteBufOutputStream out = new ByteBufOutputStream(b)) { + IPCUtil.write(out, rh, call.param, cellBlock); + } + + channel.writeAndFlush(b).addListener(new CallWriteListener(this,call)); + } catch (IOException e) { + if (!shouldCloseConnection) { + close(e); + } + } + } + + /** + * Fail a call + * + * @param call to fail + * @param cause of fail + */ + void failCall(AsyncCall call, IOException cause) { + calls.remove(call.id); + call.setFailed(cause); + } + + /** + * Set up server authorization + * + * @throws java.io.IOException if auth setup failed + */ + private void setupAuthorization() throws IOException { + SecurityInfo securityInfo = SecurityInfo.getInfo(serviceName); + this.useSasl = client.userProvider.isHBaseSecurityEnabled(); + + this.token = null; + if (useSasl && securityInfo != null) { + AuthenticationProtos.TokenIdentifier.Kind tokenKind = securityInfo.getTokenKind(); + if (tokenKind != null) { + TokenSelector tokenSelector = tokenHandlers.get(tokenKind); + if (tokenSelector != null) { + token = tokenSelector + .selectToken(new Text(client.clusterId), ticket.getUGI().getTokens()); + } else if (LOG.isDebugEnabled()) { + LOG.debug("No token selector found for type " + tokenKind); + } + } + String serverKey = securityInfo.getServerPrincipal(); + if (serverKey == null) { + throw new IOException("Can't obtain server Kerberos config key from SecurityInfo"); + } + this.serverPrincipal = SecurityUtil.getServerPrincipal(client.conf.get(serverKey), + address.getAddress().getCanonicalHostName().toLowerCase()); + if (LOG.isDebugEnabled()) { + LOG.debug("RPC Server Kerberos principal name for service=" + serviceName + " is " + + serverPrincipal); + } + } + + if (!useSasl) { + authMethod = AuthMethod.SIMPLE; + } else if (token != null) { + authMethod = AuthMethod.DIGEST; + } else { + authMethod = AuthMethod.KERBEROS; + } + + if (LOG.isDebugEnabled()) { + LOG.debug("Use " + authMethod + " authentication for service " + serviceName + + ", sasl=" + useSasl); + } + reloginMaxBackoff = client.conf.getInt("hbase.security.relogin.maxbackoff", 5000); + } + + /** + * Build the user information + * + * @param ugi User Group Information + * @param authMethod Authorization method + * @return UserInformation protobuf + */ + private RPCProtos.UserInformation buildUserInfo(UserGroupInformation ugi, AuthMethod authMethod) { + if (ugi == null || authMethod == AuthMethod.DIGEST) { + // Don't send user for token auth + return null; + } + RPCProtos.UserInformation.Builder userInfoPB = RPCProtos.UserInformation.newBuilder(); + if (authMethod == AuthMethod.KERBEROS) { + // Send effective user for Kerberos auth + userInfoPB.setEffectiveUser(ugi.getUserName()); + } else if (authMethod == AuthMethod.SIMPLE) { + //Send both effective user and real user for simple auth + userInfoPB.setEffectiveUser(ugi.getUserName()); + if (ugi.getRealUser() != null) { + userInfoPB.setRealUser(ugi.getRealUser().getUserName()); + } + } + return userInfoPB.build(); + } + + /** + * Create connection preamble + * + * @param byteBuf to write to + * @param authMethod to write + */ + private void createPreamble(ByteBuf byteBuf, AuthMethod authMethod) { + byteBuf.writeBytes(HConstants.RPC_HEADER); + byteBuf.writeByte(HConstants.RPC_CURRENT_VERSION); + byteBuf.writeByte(authMethod.code); + } + + /** + * Close connection + * + * @param e exception on close + */ + public void close(final Throwable e) { + client.removeConnection(ConnectionId.hashCode(ticket,serviceName,address)); + + // Move closing from the requesting thread to the channel thread + channel.eventLoop().execute(new Runnable() { + @Override + public void run() { + if (shouldCloseConnection) { + return; + } + + shouldCloseConnection = true; + + if (e != null) { + if (e instanceof IOException) { + closeException = (IOException) e; + } else { + closeException = new IOException(e); + } + } + + // log the info + if (LOG.isDebugEnabled() && closeException != null) { + LOG.debug(name + ": closing ipc connection to " + address + ": " + + closeException.getMessage()); + } + + cleanupCalls(true); + channel.disconnect().addListener(ChannelFutureListener.CLOSE); + + if (LOG.isDebugEnabled()) { + LOG.debug(name + ": closed"); + } + } + }); + } + + /** + * Clean up calls. + * + * @param cleanAll true if all calls should be cleaned, false for only the timed out calls + */ + public void cleanupCalls(boolean cleanAll) { + // Cancel outstanding timers + if (cleanupTimer != null) { + cleanupTimer.cancel(); + cleanupTimer = null; + } + + if (cleanAll) { + for (AsyncCall call : calls.values()) { + synchronized (call) { + // Calls can be done on another thread so check before failing them + if(!call.isDone()) { + if (closeException == null) { + failCall(call, new ConnectionClosingException("Call id=" + call.id + + " on server " + address + " aborted: connection is closing")); + } else { + failCall(call, closeException); + } + } + } + } + } else { + for (AsyncCall call : calls.values()) { + long waitTime = EnvironmentEdgeManager.currentTime() - call.getStartTime(); + long timeout = call.getRpcTimeout(); + if (timeout > 0 && waitTime >= timeout) { + synchronized (call) { + // Calls can be done on another thread so check before failing them + if (!call.isDone()) { + closeException = new CallTimeoutException("Call id=" + call.id + + ", waitTime=" + waitTime + ", rpcTimeout=" + timeout); + failCall(call, closeException); + } + } + } else { + // We expect the call to be ordered by timeout. It may not be the case, but stopping + // at the first valid call allows to be sure that we still have something to do without + // spending too much time by reading the full list. + break; + } + } + + if (!calls.isEmpty()) { + AsyncCall firstCall = calls.firstEntry().getValue(); + + final long newTimeout; + long maxWaitTime = EnvironmentEdgeManager.currentTime() - firstCall.getStartTime(); + if (maxWaitTime < firstCall.getRpcTimeout()) { + newTimeout = firstCall.getRpcTimeout() - maxWaitTime; + } else { + newTimeout = 0; + } + + closeException = null; + cleanupTimer = AsyncRpcClient.WHEEL_TIMER.newTimeout(timeoutTask, + newTimeout, TimeUnit.MILLISECONDS); + } + } + } + + /** + * Check if the connection is alive + * + * @return true if alive + */ + public boolean isAlive() { + return channel.isOpen(); + } + + /** + * Check if user should authenticate over Kerberos + * + * @return true if should be authenticated over Kerberos + * @throws java.io.IOException on failure of check + */ + private synchronized boolean shouldAuthenticateOverKrb() throws IOException { + UserGroupInformation loginUser = UserGroupInformation.getLoginUser(); + UserGroupInformation currentUser = UserGroupInformation.getCurrentUser(); + UserGroupInformation realUser = currentUser.getRealUser(); + return authMethod == AuthMethod.KERBEROS && + loginUser != null && + //Make sure user logged in using Kerberos either keytab or TGT + loginUser.hasKerberosCredentials() && + // relogin only in case it is the login user (e.g. JT) + // or superuser (like oozie). + (loginUser.equals(currentUser) || loginUser.equals(realUser)); + } + + /** + * If multiple clients with the same principal try to connect + * to the same server at the same time, the server assumes a + * replay attack is in progress. This is a feature of kerberos. + * In order to work around this, what is done is that the client + * backs off randomly and tries to initiate the connection + * again. + * The other problem is to do with ticket expiry. To handle that, + * a relogin is attempted. + *

+ * The retry logic is governed by the {@link #shouldAuthenticateOverKrb} + * method. In case when the user doesn't have valid credentials, we don't + * need to retry (from cache or ticket). In such cases, it is prudent to + * throw a runtime exception when we receive a SaslException from the + * underlying authentication implementation, so there is no retry from + * other high level (for eg, HCM or HBaseAdmin). + *

+ * + * @param currRetries retry count + * @param ex exception describing fail + * @param user which is trying to connect + * @throws java.io.IOException if IO fail + * @throws InterruptedException if thread is interrupted + */ + private void handleSaslConnectionFailure(final int currRetries, final Throwable ex, + final UserGroupInformation user) throws IOException, InterruptedException { + user.doAs(new PrivilegedExceptionAction() { + public Void run() throws IOException, InterruptedException { + if (shouldAuthenticateOverKrb()) { + if (currRetries < MAX_SASL_RETRIES) { + LOG.debug("Exception encountered while connecting to the server : " + ex); + //try re-login + if (UserGroupInformation.isLoginKeytabBased()) { + UserGroupInformation.getLoginUser().reloginFromKeytab(); + } else { + UserGroupInformation.getLoginUser().reloginFromTicketCache(); + } + + // Should reconnect + return null; + } else { + String msg = "Couldn't setup connection for " + + UserGroupInformation.getLoginUser().getUserName() + + " to " + serverPrincipal; + LOG.warn(msg); + throw (IOException) new IOException(msg).initCause(ex); + } + } else { + LOG.warn("Exception encountered while connecting to " + + "the server : " + ex); + } + if (ex instanceof RemoteException) { + throw (RemoteException) ex; + } + if (ex instanceof SaslException) { + String msg = "SASL authentication failed." + + " The most likely cause is missing or invalid credentials." + + " Consider 'kinit'."; + LOG.fatal(msg, ex); + throw new RuntimeException(msg, ex); + } + throw new IOException(ex); + } + }); + } + + @Override + public String toString() { + return this.address.toString() + "/" + this.serviceName + "/" + this.ticket; + } + + /** + * Listens to call writes and fails if write failed + */ + private static final class CallWriteListener implements ChannelFutureListener { + private final AsyncRpcChannel rpcChannel; + private final AsyncCall call; + + public CallWriteListener(AsyncRpcChannel asyncRpcChannel, AsyncCall call) { + this.rpcChannel = asyncRpcChannel; + this.call = call; + } + + @Override public void operationComplete(ChannelFuture future) throws Exception { + if (!future.isSuccess()) { + if(!this.call.isDone()) { + if (future.cause() instanceof IOException) { + rpcChannel.failCall(call, (IOException) future.cause()); + } else { + rpcChannel.failCall(call, new IOException(future.cause())); + } + } + } + } + } +} diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncRpcClient.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncRpcClient.java new file mode 100644 index 0000000..81cab67 --- /dev/null +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncRpcClient.java @@ -0,0 +1,385 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.ipc; + +import com.google.common.annotations.VisibleForTesting; +import com.google.protobuf.Descriptors; +import com.google.protobuf.Message; +import com.google.protobuf.RpcCallback; +import com.google.protobuf.RpcChannel; +import com.google.protobuf.RpcController; +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.util.HashedWheelTimer; +import io.netty.util.ResourceLeakDetector; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import io.netty.util.concurrent.Promise; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hbase.CellScanner; +import org.apache.hadoop.hbase.HConstants; +import org.apache.hadoop.hbase.ServerName; +import org.apache.hadoop.hbase.classification.InterfaceAudience; +import org.apache.hadoop.hbase.security.User; +import org.apache.hadoop.hbase.util.Pair; +import org.apache.hadoop.hbase.util.PoolMap; +import org.apache.hadoop.hbase.util.Threads; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * Netty client for the requests and responses + */ +@InterfaceAudience.Private +public class AsyncRpcClient extends AbstractRpcClient { + + public static final HashedWheelTimer WHEEL_TIMER = + new HashedWheelTimer(100, TimeUnit.MILLISECONDS); + + private static final ChannelInitializer DEFAULT_CHANNEL_INITIALIZER = + new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) throws Exception { + //empty initializer + } + }; + + protected final AtomicInteger callIdCnt = new AtomicInteger(); + + private final NioEventLoopGroup eventLoopGroup; + private final PoolMap connections; + + final FailedServers failedServers; + + private final Bootstrap bootstrap; + + /** + * Constructor for tests + * + * @param configuration to HBase + * @param clusterId for the cluster + * @param localAddress local address to connect to + * @param channelInitializer for custom channel handlers + */ + @VisibleForTesting + AsyncRpcClient(Configuration configuration, String clusterId, SocketAddress localAddress, + ChannelInitializer channelInitializer) { + super(configuration, clusterId, localAddress); + + if (LOG.isDebugEnabled()) { + LOG.debug("Starting async Hbase RPC client"); + } + + // Max amount of threads to use. 0 lets Netty decide based on amount of cores + int maxThreads = conf.getInt("hbase.rpc.client.threads.max", 0); + + this.eventLoopGroup = new NioEventLoopGroup(maxThreads, + Threads.newDaemonThreadFactory("AsyncRpcChannel")); + + this.connections = new PoolMap<>(getPoolType(configuration), getPoolSize(configuration)); + this.failedServers = new FailedServers(configuration); + + int operationTimeout = configuration.getInt(HConstants.HBASE_CLIENT_OPERATION_TIMEOUT, + HConstants.DEFAULT_HBASE_CLIENT_OPERATION_TIMEOUT); + + // Enable to detect ByteBuf leaks + ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.DISABLED); + + // Configure the default bootstrap. + this.bootstrap = new Bootstrap(); + bootstrap.group(eventLoopGroup).channel(NioSocketChannel.class) + .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) + .option(ChannelOption.TCP_NODELAY, tcpNoDelay) + .option(ChannelOption.SO_KEEPALIVE, tcpKeepAlive) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, operationTimeout); + if (channelInitializer == null) { + channelInitializer = DEFAULT_CHANNEL_INITIALIZER; + } + bootstrap.handler(channelInitializer); + if (localAddress != null) { + bootstrap.localAddress(localAddress); + } + } + + /** + * Constructor + * + * @param configuration to HBase + * @param clusterId for the cluster + * @param localAddress local address to connect to + */ + public AsyncRpcClient(Configuration configuration, String clusterId, SocketAddress localAddress) { + this(configuration, clusterId, localAddress, null); + } + + /** + * Make a call, passing param, to the IPC server running at + * address which is servicing the protocol protocol, + * with the ticket credentials, returning the value. + * Throws exceptions if there are network problems or if the remote code + * threw an exception. + * + * @param ticket Be careful which ticket you pass. A new user will mean a new Connection. + * {@link org.apache.hadoop.hbase.security.UserProvider#getCurrent()} makes a new + * instance of User each time so will be a new Connection each time. + * @return A pair with the Message response and the Cell data (if any). + * @throws InterruptedException if call is interrupted + * @throws java.io.IOException if a connection failure is encountered + */ + @Override protected Pair call(PayloadCarryingRpcController pcrc, + Descriptors.MethodDescriptor md, Message param, Message returnType, User ticket, + InetSocketAddress addr) throws IOException, InterruptedException { + + final AsyncRpcChannel connection = createRpcChannel(md.getService().getName(), addr, ticket); + + Promise promise = connection.callMethodWithPromise(md, pcrc, param, returnType); + + try { + Message response = promise.get(); + return new Pair<>(response, pcrc.cellScanner()); + } catch (ExecutionException e) { + if (e.getCause() instanceof IOException) { + throw (IOException) e.getCause(); + } else { + throw new IOException(e.getCause()); + } + } + } + + /** + * Call method async + */ + private void callMethod(Descriptors.MethodDescriptor md, final PayloadCarryingRpcController pcrc, + Message param, Message returnType, User ticket, InetSocketAddress addr, + final RpcCallback done) { + final AsyncRpcChannel connection; + try { + connection = createRpcChannel(md.getService().getName(), addr, ticket); + + connection.callMethod(md, pcrc, param, returnType).addListener( + new GenericFutureListener>() { + @Override + public void operationComplete(Future future) throws Exception { + if(!future.isSuccess()){ + Throwable cause = future.cause(); + if (cause instanceof IOException) { + pcrc.setFailed((IOException) cause); + }else{ + pcrc.setFailed(new IOException(cause)); + } + }else{ + try { + done.run(future.get()); + }catch (ExecutionException e){ + Throwable cause = e.getCause(); + if (cause instanceof IOException) { + pcrc.setFailed((IOException) cause); + }else{ + pcrc.setFailed(new IOException(cause)); + } + }catch (InterruptedException e){ + pcrc.setFailed(new IOException(e)); + } + } + } + }); + } catch (StoppedRpcClientException|FailedServerException e) { + pcrc.setFailed(e); + } + } + + /** + * Close netty + */ + public void close() { + if (LOG.isDebugEnabled()) { + LOG.debug("Stopping async HBase RPC client"); + } + + synchronized (connections) { + for (AsyncRpcChannel conn : connections.values()) { + conn.close(null); + } + } + + eventLoopGroup.shutdownGracefully(); + } + + /** + * Create a cell scanner + * + * @param cellBlock to create scanner for + * @return CellScanner + * @throws java.io.IOException on error on creation cell scanner + */ + public CellScanner createCellScanner(byte[] cellBlock) throws IOException { + return ipcUtil.createCellScanner(this.codec, this.compressor, cellBlock); + } + + /** + * Build cell block + * + * @param cells to create block with + * @return ByteBuffer with cells + * @throws java.io.IOException if block creation fails + */ + public ByteBuffer buildCellBlock(CellScanner cells) throws IOException { + return ipcUtil.buildCellBlock(this.codec, this.compressor, cells); + } + + /** + * Creates an RPC client + * + * @param serviceName name of servicce + * @param location to connect to + * @param ticket for current user + * @return new RpcChannel + * @throws StoppedRpcClientException when Rpc client is stopped + * @throws FailedServerException if server failed + */ + private AsyncRpcChannel createRpcChannel(String serviceName, InetSocketAddress location, + User ticket) throws StoppedRpcClientException, FailedServerException { + if (this.eventLoopGroup.isShuttingDown() || this.eventLoopGroup.isShutdown()) { + throw new StoppedRpcClientException(); + } + + // Check if server is failed + if (this.failedServers.isFailedServer(location)) { + if (LOG.isDebugEnabled()) { + LOG.debug("Not trying to connect to " + location + + " this server is in the failed servers list"); + } + throw new FailedServerException( + "This server is in the failed servers list: " + location); + } + + int hashCode = ConnectionId.hashCode(ticket,serviceName,location); + + AsyncRpcChannel rpcChannel; + synchronized (connections) { + rpcChannel = connections.get(hashCode); + if (rpcChannel == null) { + rpcChannel = new AsyncRpcChannel(this.bootstrap, this, ticket, serviceName, location); + connections.put(hashCode, rpcChannel); + } + } + + return rpcChannel; + } + + /** + * Interrupt the connections to the given ip:port server. This should be called if the server + * is known as actually dead. This will not prevent current operation to be retried, and, + * depending on their own behavior, they may retry on the same server. This can be a feature, + * for example at startup. In any case, they're likely to get connection refused (if the + * process died) or no route to host: i.e. there next retries should be faster and with a + * safe exception. + * + * @param sn server to cancel connections for + */ + @Override + public void cancelConnections(ServerName sn) { + synchronized (connections) { + for (AsyncRpcChannel rpcChannel : connections.values()) { + if (rpcChannel.isAlive() && + rpcChannel.address.getPort() == sn.getPort() && + rpcChannel.address.getHostName().contentEquals(sn.getHostname())) { + LOG.info("The server on " + sn.toString() + + " is dead - stopping the connection " + rpcChannel.toString()); + rpcChannel.close(null); + } + } + } + } + + /** + * Remove connection from pool + * + * @param connectionHashCode of connection + */ + public void removeConnection(int connectionHashCode) { + synchronized (connections) { + this.connections.remove(connectionHashCode); + } + } + + /** + * Creates a "channel" that can be used by a protobuf service. Useful setting up + * protobuf stubs. + * + * @param sn server name describing location of server + * @param user which is to use the connection + * @param rpcTimeout default rpc operation timeout + * + * @return A rpc channel that goes via this rpc client instance. + * @throws IOException when channel could not be created + */ + public RpcChannel createRpcChannel(final ServerName sn, final User user, int rpcTimeout) { + return new RpcChannelImplementation(this, sn, user, rpcTimeout); + } + + /** + * Blocking rpc channel that goes via hbase rpc. + */ + @VisibleForTesting + public static class RpcChannelImplementation implements RpcChannel { + private final InetSocketAddress isa; + private final AsyncRpcClient rpcClient; + private final User ticket; + private final int channelOperationTimeout; + + /** + * @param channelOperationTimeout - the default timeout when no timeout is given + */ + protected RpcChannelImplementation(final AsyncRpcClient rpcClient, + final ServerName sn, final User ticket, int channelOperationTimeout) { + this.isa = new InetSocketAddress(sn.getHostname(), sn.getPort()); + this.rpcClient = rpcClient; + this.ticket = ticket; + this.channelOperationTimeout = channelOperationTimeout; + } + + @Override + public void callMethod(Descriptors.MethodDescriptor md, RpcController controller, + Message param, Message returnType, RpcCallback done) { + PayloadCarryingRpcController pcrc; + if (controller != null) { + pcrc = (PayloadCarryingRpcController) controller; + if (!pcrc.hasCallTimeout()) { + pcrc.setCallTimeout(channelOperationTimeout); + } + } else { + pcrc = new PayloadCarryingRpcController(); + pcrc.setCallTimeout(channelOperationTimeout); + } + + this.rpcClient.callMethod(md, pcrc, param, returnType, this.ticket, this.isa, done); + } + } +} \ No newline at end of file diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncServerResponseHandler.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncServerResponseHandler.java new file mode 100644 index 0000000..d71bf5e --- /dev/null +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncServerResponseHandler.java @@ -0,0 +1,130 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.ipc; + +import com.google.protobuf.Message; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.hbase.CellScanner; +import org.apache.hadoop.hbase.classification.InterfaceAudience; +import org.apache.hadoop.hbase.protobuf.generated.RPCProtos; +import org.apache.hadoop.ipc.RemoteException; + +import java.io.IOException; + +/** + * Handles Hbase responses + */ +@InterfaceAudience.Private +public class AsyncServerResponseHandler extends ChannelInboundHandlerAdapter { + public static final Log LOG = LogFactory.getLog(AsyncServerResponseHandler.class.getName()); + + private final AsyncRpcChannel channel; + + /** + * Constructor + * + * @param channel on which this response handler operates + */ + public AsyncServerResponseHandler(AsyncRpcChannel channel) { + this.channel = channel; + } + + @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + ByteBuf inBuffer = (ByteBuf) msg; + ByteBufInputStream in = new ByteBufInputStream(inBuffer); + + if (channel.shouldCloseConnection) { + return; + } + int totalSize = inBuffer.readableBytes(); + try { + // Read the header + RPCProtos.ResponseHeader responseHeader = RPCProtos.ResponseHeader.parseDelimitedFrom(in); + int id = responseHeader.getCallId(); + AsyncCall call = channel.calls.get(id); + if (call == null) { + // So we got a response for which we have no corresponding 'call' here on the client-side. + // We probably timed out waiting, cleaned up all references, and now the server decides + // to return a response. There is nothing we can do w/ the response at this stage. Clean + // out the wire of the response so its out of the way and we can get other responses on + // this connection. + int readSoFar = IPCUtil.getTotalSizeWhenWrittenDelimited(responseHeader); + int whatIsLeftToRead = totalSize - readSoFar; + + // This is done through a Netty ByteBuf which has different behavior than InputStream. + // It does not return number of bytes read but will update pointer internally and throws an + // exception when too many bytes are to be skipped. + inBuffer.skipBytes(whatIsLeftToRead); + return; + } + + if (responseHeader.hasException()) { + RPCProtos.ExceptionResponse exceptionResponse = responseHeader.getException(); + RemoteException re = createRemoteException(exceptionResponse); + if (exceptionResponse.getExceptionClassName(). + equals(FatalConnectionException.class.getName())) { + channel.close(re); + } else { + channel.failCall(call, re); + } + } else { + Message value = null; + // Call may be null because it may have timedout and been cleaned up on this side already + if (call.responseDefaultType != null) { + Message.Builder builder = call.responseDefaultType.newBuilderForType(); + builder.mergeDelimitedFrom(in); + value = builder.build(); + } + CellScanner cellBlockScanner = null; + if (responseHeader.hasCellBlockMeta()) { + int size = responseHeader.getCellBlockMeta().getLength(); + byte[] cellBlock = new byte[size]; + inBuffer.readBytes(cellBlock, 0, cellBlock.length); + cellBlockScanner = channel.client.createCellScanner(cellBlock); + } + call.setSuccess(value, cellBlockScanner); + } + channel.calls.remove(id); + } catch (IOException e) { + // Treat this as a fatal condition and close this connection + channel.close(e); + } finally { + inBuffer.release(); + channel.cleanupCalls(false); + } + } + + /** + * @param e Proto exception + * @return RemoteException made from passed e + */ + private RemoteException createRemoteException(final RPCProtos.ExceptionResponse e) { + String innerExceptionClassName = e.getExceptionClassName(); + boolean doNotRetry = e.getDoNotRetry(); + return e.hasHostname() ? + // If a hostname then add it to the RemoteWithExtrasException + new RemoteWithExtrasException(innerExceptionClassName, e.getStackTrace(), e.getHostname(), + e.getPort(), doNotRetry) : + new RemoteWithExtrasException(innerExceptionClassName, e.getStackTrace(), doNotRetry); + } +} \ No newline at end of file diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/ConnectionId.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/ConnectionId.java index a62d415..bbd2fc7 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/ConnectionId.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/ConnectionId.java @@ -28,10 +28,10 @@ import java.net.InetSocketAddress; */ @InterfaceAudience.Private public class ConnectionId { - final InetSocketAddress address; - final User ticket; private static final int PRIME = 16777619; + final User ticket; final String serviceName; + final InetSocketAddress address; public ConnectionId(User ticket, String serviceName, InetSocketAddress address) { this.address = address; @@ -70,9 +70,12 @@ public class ConnectionId { @Override // simply use the default Object#hashcode() ? public int hashCode() { - int hashcode = (address.hashCode() + - PRIME * (PRIME * this.serviceName.hashCode() ^ - (ticket == null ? 0 : ticket.hashCode()))); - return hashcode; + return hashCode(ticket,serviceName,address); + } + + public static int hashCode(User ticket, String serviceName, InetSocketAddress address){ + return (address.hashCode() + + PRIME * (PRIME * serviceName.hashCode() ^ + (ticket == null ? 0 : ticket.hashCode()))); } } diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/PayloadCarryingRpcController.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/PayloadCarryingRpcController.java index ba7ecf8..a700dcb 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/PayloadCarryingRpcController.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/PayloadCarryingRpcController.java @@ -42,7 +42,7 @@ public class PayloadCarryingRpcController */ // Currently only multi call makes use of this. Eventually this should be only way to set // priority. - private int priority = 0; + private int priority = HConstants.NORMAL_QOS; /** * They are optionally set on construction, cleared after we make the call, and then optionally diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/RpcClient.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/RpcClient.java index 4ededd2..cf689f5 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/RpcClient.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/RpcClient.java @@ -23,6 +23,7 @@ import org.apache.hadoop.hbase.classification.InterfaceAudience; import org.apache.hadoop.hbase.security.User; import java.io.Closeable; +import java.io.IOException; /** * Interface for RpcClient implementations so ConnectionManager can handle it. @@ -56,9 +57,15 @@ import java.io.Closeable; * Creates a "channel" that can be used by a blocking protobuf service. Useful setting up * protobuf blocking stubs. * + * @param sn server name describing location of server + * @param user which is to use the connection + * @param rpcTimeout default rpc operation timeout + * * @return A blocking rpc channel that goes via this rpc client instance. + * @throws IOException when channel could not be created */ - public BlockingRpcChannel createBlockingRpcChannel(ServerName sn, User user, int rpcTimeout); + public BlockingRpcChannel createBlockingRpcChannel(ServerName sn, User user, + int rpcTimeout) throws IOException; /** * Interrupt the connections to the given server. This should be called if the server @@ -67,6 +74,7 @@ import java.io.Closeable; * for example at startup. In any case, they're likely to get connection refused (if the * process died) or no route to host: i.e. their next retries should be faster and with a * safe exception. + * @param sn server location to cancel connections of */ public void cancelConnections(ServerName sn); diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/RpcClientFactory.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/RpcClientFactory.java index 2dbb776..10ddc56 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/RpcClientFactory.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/RpcClientFactory.java @@ -59,8 +59,7 @@ public final class RpcClientFactory { public static RpcClient createClient(Configuration conf, String clusterId, SocketAddress localAddr) { String rpcClientClass = - conf.get(CUSTOM_RPC_CLIENT_IMPL_CONF_KEY, - RpcClientImpl.class.getName()); + conf.get(CUSTOM_RPC_CLIENT_IMPL_CONF_KEY, AsyncRpcClient.class.getName()); return ReflectionUtils.instantiateWithCustomCtor( rpcClientClass, new Class[] { Configuration.class, String.class, SocketAddress.class }, diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/RpcClientImpl.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/RpcClientImpl.java index 97fa475..ff5f297 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/RpcClientImpl.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/RpcClientImpl.java @@ -787,9 +787,9 @@ public class RpcClientImpl extends AbstractRpcClient { // up the reading on occasion (the passed in stream is not buffered yet). // Preamble is six bytes -- 'HBas' + VERSION + AUTH_CODE - int rpcHeaderLen = HConstants.RPC_HEADER.array().length; + int rpcHeaderLen = HConstants.RPC_HEADER.length; byte [] preamble = new byte [rpcHeaderLen + 2]; - System.arraycopy(HConstants.RPC_HEADER.array(), 0, preamble, 0, rpcHeaderLen); + System.arraycopy(HConstants.RPC_HEADER, 0, preamble, 0, rpcHeaderLen); preamble[rpcHeaderLen] = HConstants.RPC_CURRENT_VERSION; preamble[rpcHeaderLen + 1] = authMethod.code; outStream.write(preamble); @@ -1120,14 +1120,6 @@ public class RpcClientImpl extends AbstractRpcClient { } } - Pair call(PayloadCarryingRpcController pcrc, - MethodDescriptor md, Message param, CellScanner cells, - Message returnType, User ticket, InetSocketAddress addr, int rpcTimeout) - throws InterruptedException, IOException { - return - call(pcrc, md, param, cells, returnType, ticket, addr, rpcTimeout, HConstants.NORMAL_QOS); - } - /** Make a call, passing param, to the IPC server running at * address which is servicing the protocol protocol, * with the ticket credentials, returning the value. @@ -1140,21 +1132,22 @@ public class RpcClientImpl extends AbstractRpcClient { * @throws InterruptedException * @throws IOException */ - @Override protected Pair call(PayloadCarryingRpcController pcrc, MethodDescriptor md, - Message param, CellScanner cells, - Message returnType, User ticket, InetSocketAddress addr, int callTimeout, int priority) + Message param, Message returnType, User ticket, InetSocketAddress addr) throws IOException, InterruptedException { - final Call call = new Call( - this.callIdCnt.getAndIncrement(), - md, param, cells, returnType, callTimeout); + if (pcrc == null) { + pcrc = new PayloadCarryingRpcController(); + } + CellScanner cells = pcrc.cellScanner(); - final Connection connection = getConnection(ticket, call, addr, this.codec, this.compressor); + final Call call = new Call(this.callIdCnt.getAndIncrement(), md, param, cells, returnType, + pcrc.getCallTimeout()); + + final Connection connection = getConnection(ticket, call, addr); final CallFuture cts; if (connection.callSender != null) { - cts = connection.callSender.sendCall(call, priority, Trace.currentSpan()); - if (pcrc != null) { + cts = connection.callSender.sendCall(call, pcrc.getPriority(), Trace.currentSpan()); pcrc.notifyOnCancel(new RpcCallback() { @Override public void run(Object parameter) { @@ -1166,11 +1159,9 @@ public class RpcClientImpl extends AbstractRpcClient { call.callComplete(); return new Pair(call.response, call.cells); } - } - } else { cts = null; - connection.tracedWriteRequest(call, priority, Trace.currentSpan()); + connection.tracedWriteRequest(call, pcrc.getPriority(), Trace.currentSpan()); } while (!call.done) { @@ -1265,8 +1256,7 @@ public class RpcClientImpl extends AbstractRpcClient { * Get a connection from the pool, or create a new one and add it to the * pool. Connections to a given host/port are reused. */ - protected Connection getConnection(User ticket, Call call, InetSocketAddress addr, - final Codec codec, final CompressionCodec compressor) + protected Connection getConnection(User ticket, Call call, InetSocketAddress addr) throws IOException { if (!running.get()) throw new StoppedRpcClientException(); Connection connection; diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/TimeLimitedRpcController.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/TimeLimitedRpcController.java index 94b743f..de502cb 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/TimeLimitedRpcController.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/TimeLimitedRpcController.java @@ -42,8 +42,12 @@ public class TimeLimitedRpcController implements RpcController { private IOException exception; - public Integer getCallTimeout() { - return callTimeout; + public int getCallTimeout() { + if (callTimeout != null) { + return callTimeout; + } else { + return 0; + } } public void setCallTimeout(int callTimeout) { diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/HBaseSaslRpcClient.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/HBaseSaslRpcClient.java index 8f6e8e1..112ec9d 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/HBaseSaslRpcClient.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/HBaseSaslRpcClient.java @@ -117,7 +117,7 @@ public class HBaseSaslRpcClient { throw new IOException( "Failed to specify server's Kerberos principal name"); } - String names[] = SaslUtil.splitKerberosName(serverPrincipal); + String[] names = SaslUtil.splitKerberosName(serverPrincipal); if (names.length != 3) { throw new IOException( "Kerberos principal does not have the expected format: " diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslClientHandler.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslClientHandler.java new file mode 100644 index 0000000..cf92a63 --- /dev/null +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslClientHandler.java @@ -0,0 +1,353 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.security; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.hbase.classification.InterfaceAudience; +import org.apache.hadoop.ipc.RemoteException; +import org.apache.hadoop.security.token.Token; +import org.apache.hadoop.security.token.TokenIdentifier; + +import javax.security.auth.callback.CallbackHandler; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslException; +import java.io.IOException; +import java.nio.charset.Charset; +import java.util.Random; + +/** + * Handles Sasl connections + */ +@InterfaceAudience.Private +public class SaslClientHandler extends ChannelDuplexHandler { + public static final Log LOG = LogFactory.getLog(SaslClientHandler.class); + + private final boolean fallbackAllowed; + + /** + * Used for client or server's token to send or receive from each other. + */ + private final SaslClient saslClient; + private final SaslExceptionHandler exceptionHandler; + private final SaslSuccessfulConnectHandler successfulConnectHandler; + private byte[] saslToken; + private boolean firstRead = true; + + private int retryCount = 0; + private Random random; + + /** + * Constructor + * + * @param method auth method + * @param token for Sasl + * @param serverPrincipal Server's Kerberos principal name + * @param fallbackAllowed True if server may also fall back to less secure connection + * @param rpcProtection Quality of protection. Integrity or privacy + * @param exceptionHandler handler for exceptions + * @param successfulConnectHandler handler for succesful connects + * @throws java.io.IOException if handler could not be created + */ + public SaslClientHandler(AuthMethod method, Token token, + String serverPrincipal, boolean fallbackAllowed, String rpcProtection, + SaslExceptionHandler exceptionHandler, SaslSuccessfulConnectHandler successfulConnectHandler) + throws IOException { + this.fallbackAllowed = fallbackAllowed; + + this.exceptionHandler = exceptionHandler; + this.successfulConnectHandler = successfulConnectHandler; + + SaslUtil.initSaslProperties(rpcProtection); + switch (method) { + case DIGEST: + if (LOG.isDebugEnabled()) + LOG.debug("Creating SASL " + AuthMethod.DIGEST.getMechanismName() + + " client to authenticate to service at " + token.getService()); + saslClient = createDigestSaslClient(new String[] { AuthMethod.DIGEST.getMechanismName() }, + SaslUtil.SASL_DEFAULT_REALM, new HBaseSaslRpcClient.SaslClientCallbackHandler(token)); + break; + case KERBEROS: + if (LOG.isDebugEnabled()) { + LOG.debug("Creating SASL " + AuthMethod.KERBEROS.getMechanismName() + + " client. Server's Kerberos principal name is " + serverPrincipal); + } + if (serverPrincipal == null || serverPrincipal.isEmpty()) { + throw new IOException("Failed to specify server's Kerberos principal name"); + } + String[] names = SaslUtil.splitKerberosName(serverPrincipal); + if (names.length != 3) { + throw new IOException( + "Kerberos principal does not have the expected format: " + serverPrincipal); + } + saslClient = createKerberosSaslClient(new String[] { AuthMethod.KERBEROS.getMechanismName() }, + names[0], names[1]); + break; + default: + throw new IOException("Unknown authentication method " + method); + } + if (saslClient == null) + throw new IOException("Unable to find SASL client implementation"); + } + + /** + * Create a Digest Sasl client + * + * @param mechanismNames names of mechanisms + * @param saslDefaultRealm default realm for sasl + * @param saslClientCallbackHandler handler for the client + * @return new SaslClient + * @throws java.io.IOException if creation went wrong + */ + protected SaslClient createDigestSaslClient(String[] mechanismNames, String saslDefaultRealm, + CallbackHandler saslClientCallbackHandler) throws IOException { + return Sasl.createSaslClient(mechanismNames, null, null, saslDefaultRealm, SaslUtil.SASL_PROPS, + saslClientCallbackHandler); + } + + /** + * Create Kerberos client + * + * @param mechanismNames names of mechanisms + * @param userFirstPart first part of username + * @param userSecondPart second part of username + * @return new SaslClient + * @throws java.io.IOException if fails + */ + protected SaslClient createKerberosSaslClient(String[] mechanismNames, String userFirstPart, + String userSecondPart) throws IOException { + return Sasl + .createSaslClient(mechanismNames, null, userFirstPart, userSecondPart, SaslUtil.SASL_PROPS, + null); + } + + @Override public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + saslClient.dispose(); + } + + @Override public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + this.saslToken = new byte[0]; + if (saslClient.hasInitialResponse()) { + saslToken = saslClient.evaluateChallenge(saslToken); + } + if (saslToken != null) { + writeSaslToken(ctx, saslToken); + if (LOG.isDebugEnabled()) { + LOG.debug("Have sent token of size " + saslToken.length + " from initSASLContext."); + } + } + } + + @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + ByteBuf in = (ByteBuf) msg; + + // If not complete, try to negotiate + if (!saslClient.isComplete()) { + while (!saslClient.isComplete() && in.isReadable()) { + readStatus(in); + int len = in.readInt(); + if (firstRead) { + firstRead = false; + if (len == SaslUtil.SWITCH_TO_SIMPLE_AUTH) { + if (!fallbackAllowed) { + throw new IOException("Server asks us to fall back to SIMPLE auth, " + "but this " + + "client is configured to only allow secure connections."); + } + if (LOG.isDebugEnabled()) { + LOG.debug("Server asks us to fall back to simple auth."); + } + saslClient.dispose(); + + ctx.pipeline().remove(this); + successfulConnectHandler.onSuccess(ctx.channel()); + return; + } + } + saslToken = new byte[len]; + if (LOG.isDebugEnabled()) + LOG.debug("Will read input token of size " + saslToken.length + + " for processing by initSASLContext"); + in.readBytes(saslToken); + + saslToken = saslClient.evaluateChallenge(saslToken); + if (saslToken != null) { + if (LOG.isDebugEnabled()) + LOG.debug("Will send token of size " + saslToken.length + " from initSASLContext."); + writeSaslToken(ctx, saslToken); + } + } + + if (saslClient.isComplete()) { + String qop = (String) saslClient.getNegotiatedProperty(Sasl.QOP); + + if (LOG.isDebugEnabled()) { + LOG.debug("SASL client context established. Negotiated QoP: " + qop); + } + + boolean useWrap = qop != null && !"auth".equalsIgnoreCase(qop); + + if (!useWrap) { + ctx.pipeline().remove(this); + } + successfulConnectHandler.onSuccess(ctx.channel()); + } + } + // Normal wrapped reading + else { + try { + int length = in.readInt(); + if (LOG.isDebugEnabled()) { + LOG.debug("Actual length is " + length); + } + saslToken = new byte[length]; + in.readBytes(saslToken); + } catch (IndexOutOfBoundsException e) { + return; + } + try { + ByteBuf b = ctx.channel().alloc().heapBuffer(saslToken.length); + + b.writeBytes(saslClient.unwrap(saslToken, 0, saslToken.length)); + ctx.fireChannelRead(b); + + } catch (SaslException se) { + try { + saslClient.dispose(); + } catch (SaslException ignored) { + LOG.debug("Ignoring SASL exception", ignored); + } + throw se; + } + } + } + + /** + * Write SASL token + * + * @param ctx to write to + * @param saslToken to write + */ + private void writeSaslToken(final ChannelHandlerContext ctx, byte[] saslToken) { + ByteBuf b = ctx.alloc().heapBuffer(4 + saslToken.length); + b.writeInt(saslToken.length); + b.writeBytes(saslToken, 0, saslToken.length); + ctx.writeAndFlush(b).addListener(new ChannelFutureListener() { + @Override public void operationComplete(ChannelFuture future) throws Exception { + if (!future.isSuccess()) { + exceptionCaught(ctx, future.cause()); + } + } + }); + } + + /** + * Get the read status + * + * @param inStream to read + * @throws org.apache.hadoop.ipc.RemoteException if status was not success + */ + private static void readStatus(ByteBuf inStream) throws RemoteException { + int status = inStream.readInt(); // read status + if (status != SaslStatus.SUCCESS.state) { + throw new RemoteException(inStream.toString(Charset.forName("UTF-8")), + inStream.toString(Charset.forName("UTF-8"))); + } + } + + @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) + throws Exception { + saslClient.dispose(); + + ctx.close(); + + if (this.random == null) { + this.random = new Random(); + } + exceptionHandler.handle(this.retryCount++, this.random, cause); + } + + @Override public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + // If not complete, try to negotiate + if (!saslClient.isComplete()) { + super.write(ctx, msg, promise); + } else { + ByteBuf in = (ByteBuf) msg; + + try { + saslToken = saslClient.wrap(in.array(), in.readerIndex(), in.readableBytes()); + } catch (SaslException se) { + try { + saslClient.dispose(); + } catch (SaslException ignored) { + LOG.debug("Ignoring SASL exception", ignored); + } + promise.setFailure(se); + } + if (saslToken != null) { + ByteBuf out = ctx.channel().alloc().heapBuffer(4 + saslToken.length); + out.writeInt(saslToken.length); + out.writeBytes(saslToken, 0, saslToken.length); + + ctx.writeAndFlush(out).addListener(new ChannelFutureListener() { + @Override public void operationComplete(ChannelFuture future) throws Exception { + if (!future.isSuccess()) { + exceptionCaught(ctx, future.cause()); + } + } + }); + + saslToken = null; + } + } + } + + /** + * Handler for exceptions during Sasl connection + */ + public interface SaslExceptionHandler { + /** + * Handle the exception + * + * @param retryCount current retry count + * @param random to create new backoff with + * @param cause of fail + */ + public void handle(int retryCount, Random random, Throwable cause); + } + + /** + * Handler for successful connects + */ + public interface SaslSuccessfulConnectHandler { + /** + * Runs on success + * + * @param channel which is successfully authenticated + */ + public void onSuccess(Channel channel); + } +} \ No newline at end of file diff --git a/hbase-common/src/main/java/org/apache/hadoop/hbase/HConstants.java b/hbase-common/src/main/java/org/apache/hadoop/hbase/HConstants.java index 0ac3fbc..00c244b 100644 --- a/hbase-common/src/main/java/org/apache/hadoop/hbase/HConstants.java +++ b/hbase-common/src/main/java/org/apache/hadoop/hbase/HConstants.java @@ -62,7 +62,7 @@ public final class HConstants { /** * The first four bytes of Hadoop RPC connections */ - public static final ByteBuffer RPC_HEADER = ByteBuffer.wrap("HBas".getBytes()); + public static final byte[] RPC_HEADER = new byte[] { 'H', 'B', 'a', 's' }; public static final byte RPC_CURRENT_VERSION = 0; // HFileBlock constants. diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java index e8194a6..8052af1 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java @@ -44,6 +44,7 @@ import java.nio.channels.SocketChannel; import java.nio.channels.WritableByteChannel; import java.security.PrivilegedExceptionAction; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.Iterator; @@ -1408,9 +1409,9 @@ public class RpcServer implements RpcServerInterface { int count; // Check for 'HBas' magic. this.dataLengthBuffer.flip(); - if (!HConstants.RPC_HEADER.equals(dataLengthBuffer)) { + if (!Arrays.equals(HConstants.RPC_HEADER, dataLengthBuffer.array())) { return doBadPreambleHandling("Expected HEADER=" + - Bytes.toStringBinary(HConstants.RPC_HEADER.array()) + + Bytes.toStringBinary(HConstants.RPC_HEADER) + " but received HEADER=" + Bytes.toStringBinary(dataLengthBuffer.array()) + " from " + toString()); } diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestIPC.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestIPC.java index 2c70eb4..a7ad616 100644 --- a/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestIPC.java +++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestIPC.java @@ -33,9 +33,18 @@ import java.net.InetSocketAddress; import java.net.Socket; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; import javax.net.SocketFactory; +import com.google.protobuf.BlockingRpcChannel; +import com.google.protobuf.RpcCallback; +import com.google.protobuf.RpcChannel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.socket.SocketChannel; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; @@ -44,10 +53,13 @@ import org.apache.hadoop.hbase.CellScannable; import org.apache.hadoop.hbase.CellScanner; import org.apache.hadoop.hbase.CellUtil; import org.apache.hadoop.hbase.HBaseConfiguration; +import org.apache.hadoop.hbase.HBaseTestingUtility; import org.apache.hadoop.hbase.HConstants; import org.apache.hadoop.hbase.HRegionInfo; import org.apache.hadoop.hbase.KeyValue; import org.apache.hadoop.hbase.KeyValueUtil; +import org.apache.hadoop.hbase.ServerName; +import org.apache.hadoop.hbase.Waiter; import org.apache.hadoop.hbase.testclassification.SmallTests; import org.apache.hadoop.hbase.client.Put; import org.apache.hadoop.hbase.client.RowMutations; @@ -91,6 +103,7 @@ import com.google.protobuf.ServiceException; @Category(SmallTests.class) public class TestIPC { public static final Log LOG = LogFactory.getLog(TestIPC.class); + private final static HBaseTestingUtility TEST_UTIL = new HBaseTestingUtility(); static byte [] CELL_BYTES = Bytes.toBytes("xyz"); static Cell CELL = new KeyValue(CELL_BYTES, CELL_BYTES, CELL_BYTES, CELL_BYTES); static byte [] BIG_CELL_BYTES = new byte [10 * 1024]; @@ -190,8 +203,8 @@ public class TestIPC { MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo"); final String message = "hello"; EchoRequestProto param = EchoRequestProto.newBuilder().setMessage(message).build(); - Pair r = client.call(null, md, param, null, - md.getOutputType().toProto(), User.getCurrent(), address, 0); + Pair r = client.call(null, md, param, + md.getOutputType().toProto(), User.getCurrent(), address); assertTrue(r.getSecond() == null); // Silly assertion that the message is in the returned pb. assertTrue(r.getFirst().toString().contains(message)); @@ -202,6 +215,44 @@ public class TestIPC { } /** + * Ensure we do not HAVE TO HAVE a codec. + * + * @throws InterruptedException + * @throws IOException + */ + @Test public void testNoCodecAsync() throws InterruptedException, IOException, ServiceException { + Configuration conf = HBaseConfiguration.create(); + AsyncRpcClient client = new AsyncRpcClient(conf, HConstants.CLUSTER_ID_DEFAULT, null) { + @Override Codec getCodec() { + return null; + } + }; + TestRpcServer rpcServer = new TestRpcServer(); + try { + rpcServer.start(); + InetSocketAddress address = rpcServer.getListenerAddress(); + MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo"); + final String message = "hello"; + EchoRequestProto param = EchoRequestProto.newBuilder().setMessage(message).build(); + + BlockingRpcChannel channel = client + .createBlockingRpcChannel(ServerName.valueOf(address.getHostName(), address.getPort(), + System.currentTimeMillis()), User.getCurrent(), 0); + + PayloadCarryingRpcController controller = new PayloadCarryingRpcController(); + Message response = + channel.callBlockingMethod(md, controller, param, md.getOutputType().toProto()); + + assertTrue(controller.cellScanner() == null); + // Silly assertion that the message is in the returned pb. + assertTrue(response.toString().contains(message)); + } finally { + client.close(); + rpcServer.stop(); + } + } + + /** * It is hard to verify the compression is actually happening under the wraps. Hope that if * unsupported, we'll get an exception out of some time (meantime, have to trace it manually * to confirm that compression is happening down in the client and server). @@ -212,13 +263,17 @@ public class TestIPC { */ @Test public void testCompressCellBlock() - throws IOException, InterruptedException, SecurityException, NoSuchMethodException { + throws IOException, InterruptedException, SecurityException, NoSuchMethodException, + ServiceException { Configuration conf = new Configuration(HBaseConfiguration.create()); conf.set("hbase.client.rpc.compressor", GzipCodec.class.getCanonicalName()); - doSimpleTest(conf, new RpcClientImpl(conf, HConstants.CLUSTER_ID_DEFAULT)); + doSimpleTest(new RpcClientImpl(conf, HConstants.CLUSTER_ID_DEFAULT)); + + // Another test for the async client + doAsyncSimpleTest(new AsyncRpcClient(conf, HConstants.CLUSTER_ID_DEFAULT, null)); } - private void doSimpleTest(final Configuration conf, final RpcClientImpl client) + private void doSimpleTest(final RpcClientImpl client) throws InterruptedException, IOException { TestRpcServer rpcServer = new TestRpcServer(); List cells = new ArrayList(); @@ -229,8 +284,11 @@ public class TestIPC { InetSocketAddress address = rpcServer.getListenerAddress(); MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo"); EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build(); - Pair r = client.call(null, md, param, CellUtil.createCellScanner(cells), - md.getOutputType().toProto(), User.getCurrent(), address, 0); + + PayloadCarryingRpcController pcrc = + new PayloadCarryingRpcController(CellUtil.createCellScanner(cells)); + Pair r = client + .call(pcrc, md, param, md.getOutputType().toProto(), User.getCurrent(), address); int index = 0; while (r.getSecond().advance()) { assertTrue(CELL.equals(r.getSecond().current())); @@ -243,6 +301,42 @@ public class TestIPC { } } + private void doAsyncSimpleTest(final AsyncRpcClient client) + throws InterruptedException, IOException, ServiceException { + TestRpcServer rpcServer = new TestRpcServer(); + List cells = new ArrayList(); + int count = 3; + for (int i = 0; i < count; i++) + cells.add(CELL); + try { + rpcServer.start(); + InetSocketAddress address = rpcServer.getListenerAddress(); + MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo"); + EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build(); + + PayloadCarryingRpcController pcrc = + new PayloadCarryingRpcController(CellUtil.createCellScanner(cells)); + + BlockingRpcChannel channel = client.createBlockingRpcChannel( + ServerName.valueOf(address.getHostName(), address.getPort(), System.currentTimeMillis()), + User.getCurrent(), 0); + + channel.callBlockingMethod(md, pcrc, param, md.getOutputType().toProto()); + + CellScanner cellScanner = pcrc.cellScanner(); + + int index = 0; + while (cellScanner.advance()) { + assertTrue(CELL.equals(cellScanner.current())); + index++; + } + assertEquals(count, index); + } finally { + client.close(); + rpcServer.stop(); + } + } + @Test public void testRTEDuringConnectionSetup() throws Exception { Configuration conf = HBaseConfiguration.create(); @@ -263,7 +357,48 @@ public class TestIPC { InetSocketAddress address = rpcServer.getListenerAddress(); MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo"); EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build(); - client.call(null, md, param, null, null, User.getCurrent(), address, 0); + client.call(null, md, param, null, User.getCurrent(), address); + fail("Expected an exception to have been thrown!"); + } catch (Exception e) { + LOG.info("Caught expected exception: " + e.toString()); + assertTrue(StringUtils.stringifyException(e).contains("Injected fault")); + } finally { + client.close(); + rpcServer.stop(); + } + } + + @Test + public void testRTEDuringAsyncBlockingConnectionSetup() throws Exception { + Configuration conf = HBaseConfiguration.create(); + + TestRpcServer rpcServer = new TestRpcServer(); + AsyncRpcClient client = new AsyncRpcClient(conf, HConstants.CLUSTER_ID_DEFAULT, null, + new ChannelInitializer() { + + @Override protected void initChannel(SocketChannel ch) throws Exception { + ch.pipeline().addFirst(new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + promise.setFailure(new RuntimeException("Injected fault")); + } + }); + } + }); + try { + rpcServer.start(); + InetSocketAddress address = rpcServer.getListenerAddress(); + MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo"); + EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build(); + + BlockingRpcChannel channel = client.createBlockingRpcChannel( + ServerName.valueOf(address.getHostName(), address.getPort(), System.currentTimeMillis()), + User.getCurrent(), 0); + + channel.callBlockingMethod(md, new PayloadCarryingRpcController(), param, + md.getOutputType().toProto()); + fail("Expected an exception to have been thrown!"); } catch (Exception e) { LOG.info("Caught expected exception: " + e.toString()); @@ -274,6 +409,106 @@ public class TestIPC { } } + + @Test + public void testRTEDuringAsyncConnectionSetup() throws Exception { + Configuration conf = HBaseConfiguration.create(); + + TestRpcServer rpcServer = new TestRpcServer(); + AsyncRpcClient client = new AsyncRpcClient(conf, HConstants.CLUSTER_ID_DEFAULT, null, + new ChannelInitializer() { + + @Override protected void initChannel(SocketChannel ch) throws Exception { + ch.pipeline().addFirst(new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + promise.setFailure(new RuntimeException("Injected fault")); + } + }); + } + }); + try { + rpcServer.start(); + InetSocketAddress address = rpcServer.getListenerAddress(); + MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo"); + EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build(); + + RpcChannel channel = client.createRpcChannel( + ServerName.valueOf(address.getHostName(), address.getPort(), System.currentTimeMillis()), + User.getCurrent(), 0); + + final AtomicBoolean done = new AtomicBoolean(false); + + PayloadCarryingRpcController controller = new PayloadCarryingRpcController(); + controller.notifyOnFail(new RpcCallback() { + @Override + public void run(IOException e) { + done.set(true); + LOG.info("Caught expected exception: " + e.toString()); + assertTrue(StringUtils.stringifyException(e).contains("Injected fault")); + } + }); + + channel.callMethod(md, controller, param, + md.getOutputType().toProto(), new RpcCallback() { + @Override + public void run(Message parameter) { + done.set(true); + fail("Expected an exception to have been thrown!"); + } + }); + + TEST_UTIL.waitFor(1000, new Waiter.Predicate() { + @Override + public boolean evaluate() throws Exception { + return done.get(); + } + }); + } finally { + client.close(); + rpcServer.stop(); + } + } + + @Test + public void testAsyncConnectionSetup() throws Exception { + Configuration conf = HBaseConfiguration.create(); + + TestRpcServer rpcServer = new TestRpcServer(); + AsyncRpcClient client = new AsyncRpcClient(conf, HConstants.CLUSTER_ID_DEFAULT, null); + try { + rpcServer.start(); + InetSocketAddress address = rpcServer.getListenerAddress(); + MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo"); + EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build(); + + RpcChannel channel = client.createRpcChannel( + ServerName.valueOf(address.getHostName(), address.getPort(), System.currentTimeMillis()), + User.getCurrent(), 0); + + final AtomicBoolean done = new AtomicBoolean(false); + + channel.callMethod(md, new PayloadCarryingRpcController(), param, + md.getOutputType().toProto(), new RpcCallback() { + @Override + public void run(Message parameter) { + done.set(true); + } + }); + + TEST_UTIL.waitFor(1000, new Waiter.Predicate() { + @Override + public boolean evaluate() throws Exception { + return done.get(); + } + }); + } finally { + client.close(); + rpcServer.stop(); + } + } + /** Tests that the rpc scheduler is called when requests arrive. */ @Test public void testRpcScheduler() throws IOException, InterruptedException { @@ -287,8 +522,43 @@ public class TestIPC { MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo"); EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build(); for (int i = 0; i < 10; i++) { - client.call(null, md, param, CellUtil.createCellScanner(ImmutableList.of(CELL)), - md.getOutputType().toProto(), User.getCurrent(), rpcServer.getListenerAddress(), 0); + client.call( + new PayloadCarryingRpcController(CellUtil.createCellScanner(ImmutableList.of(CELL))), + md, param, md.getOutputType().toProto(), User.getCurrent(), + rpcServer.getListenerAddress()); + } + verify(scheduler, times(10)).dispatch((CallRunner) anyObject()); + } finally { + rpcServer.stop(); + verify(scheduler).stop(); + } + } + + /** + * Tests that the rpc scheduler is called when requests arrive. + */ + @Test + public void testRpcSchedulerAsync() + throws IOException, InterruptedException, ServiceException { + RpcScheduler scheduler = spy(new FifoRpcScheduler(CONF, 1)); + RpcServer rpcServer = new TestRpcServer(scheduler); + verify(scheduler).init((RpcScheduler.Context) anyObject()); + AbstractRpcClient client = new AsyncRpcClient(CONF, HConstants.CLUSTER_ID_DEFAULT, null); + try { + rpcServer.start(); + verify(scheduler).start(); + MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo"); + EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build(); + ServerName serverName = ServerName.valueOf(rpcServer.getListenerAddress().getHostName(), + rpcServer.getListenerAddress().getPort(), System.currentTimeMillis()); + + for (int i = 0; i < 10; i++) { + BlockingRpcChannel channel = client.createBlockingRpcChannel( + serverName, User.getCurrent(), 0); + + channel.callBlockingMethod(md, + new PayloadCarryingRpcController(CellUtil.createCellScanner(ImmutableList.of(CELL))), + param, md.getOutputType().toProto()); } verify(scheduler, times(10)).dispatch((CallRunner) anyObject()); } finally { @@ -340,9 +610,10 @@ public class TestIPC { // ReflectionUtils.printThreadInfo(new PrintWriter(System.out), // "Thread dump " + Thread.currentThread().getName()); } - CellScanner cellScanner = CellUtil.createCellScanner(cells); + PayloadCarryingRpcController pcrc = + new PayloadCarryingRpcController(CellUtil.createCellScanner(cells)); Pair response = - client.call(null, md, builder.build(), cellScanner, param, user, address, 0); + client.call(pcrc, md, builder.build(), param, user, address); /* int count = 0; while (p.getSecond().advance()) { diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestRpcHandlerException.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestRpcHandlerException.java index 9cb1cc5..26510a8 100644 --- a/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestRpcHandlerException.java +++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestRpcHandlerException.java @@ -17,13 +17,10 @@ */ package org.apache.hadoop.hbase.ipc; -import static org.mockito.Mockito.mock; - import java.io.IOException; import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.List; - import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; @@ -168,7 +165,7 @@ public class TestRpcHandlerException { @Ignore @Test public void testRpcScheduler() throws IOException, InterruptedException { - PriorityFunction qosFunction = mock(PriorityFunction.class); + PriorityFunction qosFunction = Mockito.mock(PriorityFunction.class); Abortable abortable = new AbortServer(); RpcScheduler scheduler = new SimpleRpcScheduler(CONF, 2, 0, 0, qosFunction, abortable, 0); RpcServer rpcServer = new TestRpcServer(scheduler); @@ -177,8 +174,10 @@ public class TestRpcHandlerException { rpcServer.start(); MethodDescriptor md = SERVICE.getDescriptorForType().findMethodByName("echo"); EchoRequestProto param = EchoRequestProto.newBuilder().setMessage("hello").build(); - client.call(null, md, param, CellUtil.createCellScanner(ImmutableList.of(CELL)), md - .getOutputType().toProto(), User.getCurrent(), rpcServer.getListenerAddress(), 0); + PayloadCarryingRpcController controller = + new PayloadCarryingRpcController(CellUtil.createCellScanner(ImmutableList.of(CELL))); + client.call(controller, md, param, md.getOutputType().toProto(), User.getCurrent(), + rpcServer.getListenerAddress()); } catch (Throwable e) { assert(abortable.isAborted() == true); } finally { -- 2.2.1