diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/BufferChain.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/BufferChain.java index 26bc56c..bd0515a 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/BufferChain.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/BufferChain.java @@ -32,11 +32,13 @@ class BufferChain { private final ByteBuffer[] buffers; private int remaining = 0; private int bufferOffset = 0; + private int size; BufferChain(ByteBuffer[] buffers) { for (ByteBuffer b : buffers) { this.remaining += b.remaining(); } + this.size = remaining; this.buffers = buffers; } @@ -108,4 +110,12 @@ class BufferChain { } } } + + int size() { + return size; + } + + ByteBuffer[] getBuffers() { + return this.buffers; + } } diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServer.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServer.java new file mode 100644 index 0000000..b7ae8c4 --- /dev/null +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServer.java @@ -0,0 +1,1217 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.hbase.ipc; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.epoll.EpollServerSocketChannel; +import io.netty.channel.group.ChannelGroup; +import io.netty.channel.group.DefaultChannelGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.DecoderException; +import io.netty.handler.codec.ReplayingDecoder; +import io.netty.util.concurrent.GlobalEventExecutor; +import io.netty.util.internal.RecyclableArrayList; +import io.netty.util.internal.StringUtil; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.security.PrivilegedExceptionAction; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CountDownLatch; + +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hbase.CellScanner; +import org.apache.hadoop.hbase.DoNotRetryIOException; +import org.apache.hadoop.hbase.HConstants; +import org.apache.hadoop.hbase.Server; +import org.apache.hadoop.hbase.classification.InterfaceStability; +import org.apache.hadoop.hbase.client.VersionInfoUtil; +import org.apache.hadoop.hbase.io.ByteBufferOutputStream; +import org.apache.hadoop.hbase.monitoring.MonitoredRPCHandler; +import org.apache.hadoop.hbase.nio.ByteBuff; +import org.apache.hadoop.hbase.nio.SingleByteBuff; +import org.apache.hadoop.hbase.security.AccessDeniedException; +import org.apache.hadoop.hbase.security.AuthMethod; +import org.apache.hadoop.hbase.security.HBasePolicyProvider; +import org.apache.hadoop.hbase.security.HBaseSaslRpcServer; +import org.apache.hadoop.hbase.security.HBaseSaslRpcServer.SaslDigestCallbackHandler; +import org.apache.hadoop.hbase.security.HBaseSaslRpcServer.SaslGssCallbackHandler; +import org.apache.hadoop.hbase.security.SaslStatus; +import org.apache.hadoop.hbase.security.SaslUtil; +import org.apache.hadoop.hbase.shaded.com.google.protobuf.BlockingService; +import org.apache.hadoop.hbase.shaded.com.google.protobuf.CodedInputStream; +import org.apache.hadoop.hbase.shaded.com.google.protobuf.Descriptors.MethodDescriptor; +import org.apache.hadoop.hbase.shaded.com.google.protobuf.Message; +import org.apache.hadoop.hbase.shaded.com.google.protobuf.TextFormat; +import 
org.apache.hadoop.hbase.shaded.com.google.protobuf.UnsafeByteOperations; +import org.apache.hadoop.hbase.shaded.protobuf.ProtobufUtil; +import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos; +import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.ConnectionHeader; +import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.RequestHeader; +import org.apache.hadoop.hbase.util.Bytes; +import org.apache.hadoop.hbase.util.JVM; +import org.apache.hadoop.hbase.util.Pair; +import org.apache.hadoop.io.BytesWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.WritableUtils; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod; +import org.apache.hadoop.security.authorize.AuthorizationException; +import org.apache.hadoop.security.authorize.ProxyUsers; +import org.apache.hadoop.security.authorize.ServiceAuthorizationManager; +import org.apache.hadoop.security.token.SecretManager.InvalidToken; +import org.apache.htrace.TraceInfo; + +/** + * An RPC server with Netty4 implementation. + * + */ +public class NettyRpcServer extends RpcServer { + + public static final Log LOG = LogFactory.getLog(NettyRpcServer.class); + + protected final InetSocketAddress bindAddress; + + private final CountDownLatch closed = new CountDownLatch(1); + private final Channel serverChannel; + private final ChannelGroup allChannels = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE); + + public NettyRpcServer(final Server server, final String name, + final List<BlockingServiceAndInterface> services, + final InetSocketAddress bindAddress, Configuration conf, + RpcScheduler scheduler) throws IOException { + super(server, name, services, bindAddress, conf, scheduler); + this.bindAddress = bindAddress; + boolean useEpoll = useEpoll(conf); + int workerCount = conf.getInt("hbase.rpc.server.netty.work.size", 12); + LOG.info("hbase.rpc.server.netty.work.size: " + workerCount); + EventLoopGroup bossGroup = null; + EventLoopGroup workerGroup = null; + if (useEpoll) { + bossGroup = new EpollEventLoopGroup(2); + workerGroup = new EpollEventLoopGroup(workerCount); + } else { + bossGroup = new NioEventLoopGroup(2); + workerGroup = new NioEventLoopGroup(workerCount); + } + ServerBootstrap bootstrap = new ServerBootstrap(); + bootstrap.group(bossGroup, workerGroup); + if (useEpoll) { + bootstrap.channel(EpollServerSocketChannel.class); + } else { + bootstrap.channel(NioServerSocketChannel.class); + } + bootstrap.childOption(ChannelOption.TCP_NODELAY, true); + bootstrap.childOption(ChannelOption.SO_KEEPALIVE, true); + bootstrap.childOption(ChannelOption.SO_LINGER, 0); + bootstrap.childOption(ChannelOption.ALLOCATOR, + PooledByteBufAllocator.DEFAULT); + bootstrap.childHandler(new Initializer()); + + try { + serverChannel = bootstrap.bind(this.bindAddress).sync().channel(); + LOG.info("NettyRpcServer bound to address=" + serverChannel.localAddress()); + allChannels.add(serverChannel); + } catch (InterruptedException e) { + throw new IOException(e); + } + initReconfigurable(conf); + this.scheduler.init(new RpcSchedulerContext(this)); + } + + private boolean useEpoll(Configuration conf) { + // Config to enable native transport.
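+ // Defaults to true, but epoll is only actually used on Linux/amd64 (checked below); other platforms fall back to NIO.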
+ boolean epollEnabled = conf.getBoolean("hbase.rpc.server.nativetransport", true); + // Use the faster native epoll transport mechanism on linux if enabled + if (epollEnabled && JVM.isLinux() && JVM.isAmd64()) { + return true; + } else { + return false; + } + } + + @Override + public void start() { + if (started) { + return; + } + authTokenSecretMgr = createSecretManager(); + if (authTokenSecretMgr != null) { + setSecretManager(authTokenSecretMgr); + authTokenSecretMgr.start(); + } + this.authManager = new ServiceAuthorizationManager(); + HBasePolicyProvider.init(conf, authManager); + scheduler.start(); + started = true; + } + + @Override + public void stop() { + LOG.info("Stopping server on " + this.bindAddress.getPort()); + if (authTokenSecretMgr != null) { + authTokenSecretMgr.stop(); + authTokenSecretMgr = null; + } + allChannels.close().awaitUninterruptibly(); + serverChannel.close(); + scheduler.stop(); + closed.countDown(); + } + + @Override + public void join() throws InterruptedException { + closed.await(); + } + + @Override + public InetSocketAddress getListenerAddress() { + return ((InetSocketAddress) serverChannel.localAddress()); + } + + private void setupResponse(ByteArrayOutputStream response, Call call, Throwable t, String error) + throws IOException { + if (response != null) response.reset(); + call.setResponse(null, null, t, error); + } + + class Connection extends RpcServer.Connection { + + private final Call authFailedCall = new Call(AUTHORIZATION_FAILED_CALLID, + null, null, null, null, null, this, 0, null, null, 0, null); + + private final Call saslCall = new Call(SASL_CALLID, null, null, null, null, + null, this, 0, null, null, 0, null); + + private final Call setConnectionHeaderResponseCall = new Call( + CONNECTION_HEADER_RESPONSE_CALLID, null, null, null, null, null, this, + 0, null, null, 0, null); + + private ByteBuffer unwrappedData; + // When is this set? FindBugs wants to know! Says NP + private ByteBuffer unwrappedDataLengthBuffer = ByteBuffer.allocate(4); + protected Channel channel; + + Connection(Channel channel) { + super(); + this.channel = channel; + InetSocketAddress inetSocketAddress = ((InetSocketAddress) channel.remoteAddress()); + this.addr = inetSocketAddress.getAddress(); + if (addr == null) { + this.hostAddress = "*Unknown*"; + } else { + this.hostAddress = inetSocketAddress.getAddress().getHostAddress(); + } + this.remotePort = inetSocketAddress.getPort(); + } + + void readPreamble(ByteBuf buffer) throws IOException { + byte[] rpcHead = + { buffer.readByte(), buffer.readByte(), buffer.readByte(), buffer.readByte() }; + if (!Arrays.equals(HConstants.RPC_HEADER, rpcHead)) { + doBadPreambleHandling("Expected HEADER=" + + Bytes.toStringBinary(HConstants.RPC_HEADER) + " but received HEADER=" + + Bytes.toStringBinary(rpcHead) + " from " + toString()); + return; + } + // Now read the next two bytes, the version and the auth to use. 
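+ // Full preamble layout: four magic bytes ("HBas"), one version byte, one auth-method byte.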
+ int version = buffer.readByte(); + byte authbyte = buffer.readByte(); + this.authMethod = AuthMethod.valueOf(authbyte); + if (version != CURRENT_VERSION) { + String msg = getFatalConnectionString(version, authbyte); + doBadPreambleHandling(msg, new WrongVersionException(msg)); + return; + } + if (authMethod == null) { + String msg = getFatalConnectionString(version, authbyte); + doBadPreambleHandling(msg, new BadAuthException(msg)); + return; + } + if (isSecurityEnabled && authMethod == AuthMethod.SIMPLE) { + if (allowFallbackToSimpleAuth) { + metrics.authenticationFallback(); + authenticatedWithFallback = true; + } else { + AccessDeniedException ae = new AccessDeniedException( + "Authentication is required"); + setupResponse(authFailedResponse, authFailedCall, ae, ae.getMessage()); + authFailedCall.sendResponseIfReady(ChannelFutureListener.CLOSE); + return; + } + } + if (!isSecurityEnabled && authMethod != AuthMethod.SIMPLE) { + doRawSaslReply(SaslStatus.SUCCESS, new IntWritable(SaslUtil.SWITCH_TO_SIMPLE_AUTH), null, + null); + authMethod = AuthMethod.SIMPLE; + // client has already sent the initial Sasl message and we + // should ignore it. Both client and server should fall back + // to simple auth from now on. + skipInitialSaslHandshake = true; + } + if (authMethod != AuthMethod.SIMPLE) { + useSasl = true; + } + } + + /** + * No protobuf encoding of raw sasl messages + */ + private void doRawSaslReply(SaslStatus status, Writable rv, + String errorClass, String error) throws IOException { + ByteBufferOutputStream saslResponse = null; + DataOutputStream out = null; + try { + // In my testing, have noticed that sasl messages are usually + // in the ballpark of 100-200. That's why the initial capacity is 256. + saslResponse = new ByteBufferOutputStream(256); + out = new DataOutputStream(saslResponse); + out.writeInt(status.state); // write status + if (status == SaslStatus.SUCCESS) { + rv.write(out); + } else { + WritableUtils.writeString(out, errorClass); + WritableUtils.writeString(out, error); + } + saslCall.setSaslTokenResponse(saslResponse.getByteBuffer()); + saslCall.sendResponseIfReady(); + } finally { + if (saslResponse != null) { + saslResponse.close(); + } + if (out != null) { + out.close(); + } + } + } + + private void doBadPreambleHandling(final String msg) throws IOException { + doBadPreambleHandling(msg, new FatalConnectionException(msg)); + } + + private void doBadPreambleHandling(final String msg, final Exception e) throws IOException { + LOG.warn(msg); + Call fakeCall = new Call(-1, null, null, null, null, null, this, -1, + null, null, 0, null); + setupResponse(null, fakeCall, e, msg); + // closes out the connection. + fakeCall.sendResponseIfReady(ChannelFutureListener.CLOSE); + } + + void process(ByteBuffer buf) throws IOException, InterruptedException { + process(new SingleByteBuff(buf)); + } + + void process(ByteBuff buf) throws IOException, InterruptedException { + if (skipInitialSaslHandshake) { + skipInitialSaslHandshake = false; + return; + } + + if (useSasl) { + saslReadAndProcess(buf); + } else { + processOneRpc(buf); + } + } + + private void processOneRpc(ByteBuff buf) throws IOException, InterruptedException { + if (connectionHeaderRead) { + processRequest(buf); + } else { + processConnectionHeader(buf); + this.connectionHeaderRead = true; + if (!authorizeConnection()) { + // Throw FatalConnectionException wrapping ACE so client does right thing and closes + // down the connection instead of trying to read non-existent return.
+ throw new AccessDeniedException("Connection from " + this + " for service " + + connectionHeader.getServiceName() + " is unauthorized for user: " + ugi); + } + this.user = userProvider.create(this.ugi); + } + } + + // Reads the connection header following version + private void processConnectionHeader(ByteBuff buf) throws IOException { + if (buf.hasArray()) { + this.connectionHeader = ConnectionHeader.parseFrom(buf.array()); + } else { + CodedInputStream cis = UnsafeByteOperations + .unsafeWrap(new ByteBuffByteInput(buf, 0, buf.limit()), 0, buf.limit()).newCodedInput(); + cis.enableAliasing(true); + this.connectionHeader = ConnectionHeader.parseFrom(cis); + } + String serviceName = connectionHeader.getServiceName(); + if (serviceName == null) throw new EmptyServiceNameException(); + this.service = getService(services, serviceName); + if (this.service == null) throw new UnknownServiceException(serviceName); + setupCellBlockCodecs(this.connectionHeader); + RPCProtos.ConnectionHeaderResponse.Builder chrBuilder = + RPCProtos.ConnectionHeaderResponse.newBuilder(); + setupCryptoCipher(this.connectionHeader, chrBuilder); + responseConnectionHeader(chrBuilder); + UserGroupInformation protocolUser = createUser(connectionHeader); + if (!useSasl) { + ugi = protocolUser; + if (ugi != null) { + ugi.setAuthenticationMethod(AuthMethod.SIMPLE.authenticationMethod); + } + // audit logging for SASL authenticated users happens in saslReadAndProcess() + if (authenticatedWithFallback) { + LOG.warn("Allowed fallback to SIMPLE auth for " + ugi + + " connecting from " + getHostAddress()); + } + AUDITLOG.info(AUTH_SUCCESSFUL_FOR + ugi); + } else { + // user is authenticated + ugi.setAuthenticationMethod(authMethod.authenticationMethod); + //Now we check if this is a proxy user case. If the protocol user is + //different from the 'user', it is a proxy user scenario. However, + //this is not allowed if user authenticated with DIGEST. + if ((protocolUser != null) + && (!protocolUser.getUserName().equals(ugi.getUserName()))) { + if (authMethod == AuthMethod.DIGEST) { + // Not allowed to doAs if token authentication is used + throw new AccessDeniedException("Authenticated user (" + ugi + + ") doesn't match what the client claims to be (" + + protocolUser + ")"); + } else { + // Effective user can be different from authenticated user + // for simple auth or kerberos auth + // The user is the real user. Now we create a proxy user + UserGroupInformation realUser = ugi; + ugi = UserGroupInformation.createProxyUser(protocolUser + .getUserName(), realUser); + // Now the user is a proxy user, set Authentication method Proxy. 
+ ugi.setAuthenticationMethod(AuthenticationMethod.PROXY); + } + } + } + if (connectionHeader.hasVersionInfo()) { + // see if this connection will support RetryImmediatelyException + retryImmediatelySupported = VersionInfoUtil.hasMinimumVersion(getVersionInfo(), 1, 2); + + AUDITLOG.info("Connection from " + this.hostAddress + " port: " + this.remotePort + + " with version info: " + + TextFormat.shortDebugString(connectionHeader.getVersionInfo())); + } else { + AUDITLOG.info("Connection from " + this.hostAddress + " port: " + this.remotePort + + " with unknown version info"); + } + } + + private void responseConnectionHeader(RPCProtos.ConnectionHeaderResponse.Builder chrBuilder) + throws FatalConnectionException { + // Response the connection header if Crypto AES is enabled + if (!chrBuilder.hasCryptoCipherMeta()) return; + try { + byte[] connectionHeaderResBytes = chrBuilder.build().toByteArray(); + // encrypt the Crypto AES cipher meta data with sasl server, and send to client + byte[] unwrapped = new byte[connectionHeaderResBytes.length + 4]; + Bytes.putBytes(unwrapped, 0, Bytes.toBytes(connectionHeaderResBytes.length), 0, 4); + Bytes.putBytes(unwrapped, 4, connectionHeaderResBytes, 0, connectionHeaderResBytes.length); + + doConnectionHeaderResponse(saslServer.wrap(unwrapped, 0, unwrapped.length)); + } catch (IOException ex) { + throw new UnsupportedCryptoException(ex.getMessage(), ex); + } + } + + /** + * Send the response for connection header + */ + private void doConnectionHeaderResponse(byte[] wrappedCipherMetaData) + throws IOException { + ByteBufferOutputStream response = null; + DataOutputStream out = null; + try { + response = new ByteBufferOutputStream(wrappedCipherMetaData.length + 4); + out = new DataOutputStream(response); + out.writeInt(wrappedCipherMetaData.length); + out.write(wrappedCipherMetaData); + + setConnectionHeaderResponseCall.setConnectionHeaderResponse(response + .getByteBuffer()); + setConnectionHeaderResponseCall.sendResponseIfReady(); + } finally { + if (out != null) { + out.close(); + } + if (response != null) { + response.close(); + } + } + } + + /** + * @param buf Has the request header and the request param and optionally encoded data buffer + * all in this one array. + * @throws IOException + * @throws InterruptedException + */ + void processRequest(ByteBuff buf) throws IOException, + InterruptedException { + long totalRequestSize = buf.limit(); + int offset = 0; + // Here we read in the header. We avoid having pb + // do its default 4k allocation for CodedInputStream. We force it to use backing array. + CodedInputStream cis; + if (buf.hasArray()) { + cis = UnsafeByteOperations.unsafeWrap(buf.array(), 0, buf.limit()).newCodedInput(); + } else { + cis = UnsafeByteOperations + .unsafeWrap(new ByteBuffByteInput(buf, 0, buf.limit()), 0, buf.limit()).newCodedInput(); + } + cis.enableAliasing(true); + int headerSize = cis.readRawVarint32(); + offset = cis.getTotalBytesRead(); + Message.Builder builder = RequestHeader.newBuilder(); + ProtobufUtil.mergeFrom(builder, cis, headerSize); + RequestHeader header = (RequestHeader) builder.build(); + offset += headerSize; + int id = header.getCallId(); + if (LOG.isTraceEnabled()) { + LOG.trace("RequestHeader " + TextFormat.shortDebugString(header) + + " totalRequestSize: " + totalRequestSize + " bytes"); + } + // Enforcing the call queue size, this triggers a retry in the client + // This is a bit late to be doing this check - we have already read in the total request. 
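+ // Answering with CALL_QUEUE_TOO_BIG_EXCEPTION (rather than dropping the connection) lets the client back off and retry.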
+ if ((totalRequestSize + callQueueSizeInBytes.sum()) > maxQueueSizeInBytes) { + final Call callTooBig = new Call(id, this.service, null, null, null, + null, this, totalRequestSize, null, null, 0, null); + ByteArrayOutputStream responseBuffer = new ByteArrayOutputStream(); + metrics.exception(CALL_QUEUE_TOO_BIG_EXCEPTION); + setupResponse(responseBuffer, callTooBig, CALL_QUEUE_TOO_BIG_EXCEPTION, + "Call queue is full on " + getListenerAddress() + + ", is hbase.ipc.server.max.callqueue.size too small?"); + callTooBig.sendResponseIfReady(); + return; + } + MethodDescriptor md = null; + Message param = null; + CellScanner cellScanner = null; + try { + if (header.hasRequestParam() && header.getRequestParam()) { + md = this.service.getDescriptorForType().findMethodByName(header.getMethodName()); + if (md == null) throw new UnsupportedOperationException(header.getMethodName()); + builder = this.service.getRequestPrototype(md).newBuilderForType(); + cis.resetSizeCounter(); + int paramSize = cis.readRawVarint32(); + offset += cis.getTotalBytesRead(); + if (builder != null) { + ProtobufUtil.mergeFrom(builder, cis, paramSize); + param = builder.build(); + } + offset += paramSize; + } else { + // currently header must have request param, so we directly throw exception here + String msg = "Invalid request header: " + TextFormat.shortDebugString(header) + + ", should have param set in it"; + LOG.warn(msg); + throw new DoNotRetryIOException(msg); + } + if (header.hasCellBlockMeta()) { + buf.position(offset); + ByteBuff dup = buf.duplicate(); + dup.limit(offset + header.getCellBlockMeta().getLength()); + cellScanner = cellBlockBuilder.createCellScannerReusingBuffers( + this.codec, this.compressionCodec, dup); + } + } catch (Throwable t) { + InetSocketAddress address = getListenerAddress(); + String msg = (address != null ? address : "(channel closed)") + + " is unable to read call parameter from client " + getHostAddress(); + LOG.warn(msg, t); + + metrics.exception(t); + + // probably the hbase hadoop version does not match the running hadoop version + if (t instanceof LinkageError) { + t = new DoNotRetryIOException(t); + } + // If the method is not present on the server, do not retry. + if (t instanceof UnsupportedOperationException) { + t = new DoNotRetryIOException(t); + } + + final Call readParamsFailedCall = new Call(id, this.service, null, + null, null, null, this, totalRequestSize, null, null, 0, null); + ByteArrayOutputStream responseBuffer = new ByteArrayOutputStream(); + setupResponse(responseBuffer, readParamsFailedCall, t, msg + "; " + t.getMessage()); + readParamsFailedCall.sendResponseIfReady(); + return; + } + + TraceInfo traceInfo = header.hasTraceInfo() ? 
new TraceInfo(header + .getTraceInfo().getTraceId(), header.getTraceInfo().getParentId()) + : null; + int timeout = 0; + if (header.hasTimeout() && header.getTimeout() > 0) { + timeout = Math.max(minClientRequestTimeout, header.getTimeout()); + } + Call call = new Call(id, this.service, md, header, param, cellScanner, + this, totalRequestSize, traceInfo, this.addr, timeout, + this.callCleanup); + + if (!scheduler.dispatch(new CallRunner(NettyRpcServer.this, call))) { + callQueueSizeInBytes.add(-1 * call.getSize()); + + ByteArrayOutputStream responseBuffer = new ByteArrayOutputStream(); + metrics.exception(CALL_QUEUE_TOO_BIG_EXCEPTION); + setupResponse(responseBuffer, call, CALL_QUEUE_TOO_BIG_EXCEPTION, + "Call queue is full on " + server.getServerName() + + ", too many items queued ?"); + call.sendResponseIfReady(); + } + } + + private void saslReadAndProcess(ByteBuff saslToken) throws IOException, + InterruptedException { + if (saslContextEstablished) { + if (LOG.isTraceEnabled()) + LOG.trace("Have read input token of size " + saslToken.limit() + + " for processing by saslServer.unwrap()"); + + if (!useWrap) { + processOneRpc(saslToken); + } else { + byte[] b = saslToken.hasArray() ? saslToken.array() : saslToken.toBytes(); + byte [] plaintextData; + if (useCryptoAesWrap) { + // unwrap with CryptoAES + plaintextData = cryptoAES.unwrap(b, 0, b.length); + } else { + plaintextData = saslServer.unwrap(b, 0, b.length); + } + processUnwrappedData(plaintextData); + } + } else { + byte[] replyToken; + try { + if (saslServer == null) { + switch (authMethod) { + case DIGEST: + if (secretManager == null) { + throw new AccessDeniedException( + "Server is not configured to do DIGEST authentication."); + } + saslServer = Sasl.createSaslServer(AuthMethod.DIGEST + .getMechanismName(), null, SaslUtil.SASL_DEFAULT_REALM, + HBaseSaslRpcServer.getSaslProps(), new SaslDigestCallbackHandler( + secretManager, this)); + break; + default: + UserGroupInformation current = UserGroupInformation.getCurrentUser(); + String fullName = current.getUserName(); + if (LOG.isDebugEnabled()) { + LOG.debug("Kerberos principal name is " + fullName); + } + final String names[] = SaslUtil.splitKerberosName(fullName); + if (names.length != 3) { + throw new AccessDeniedException( + "Kerberos principal name does NOT have the expected " + + "hostname part: " + fullName); + } + current.doAs(new PrivilegedExceptionAction() { + @Override + public Object run() throws SaslException { + saslServer = Sasl.createSaslServer(AuthMethod.KERBEROS + .getMechanismName(), names[0], names[1], + HBaseSaslRpcServer.getSaslProps(), new SaslGssCallbackHandler()); + return null; + } + }); + } + if (saslServer == null) + throw new AccessDeniedException( + "Unable to find SASL server implementation for " + + authMethod.getMechanismName()); + if (LOG.isDebugEnabled()) { + LOG.debug("Created SASL server with mechanism = " + authMethod.getMechanismName()); + } + } + if (LOG.isDebugEnabled()) { + LOG.debug("Have read input token of size " + saslToken.limit() + + " for processing by saslServer.evaluateResponse()"); + } + replyToken = saslServer + .evaluateResponse(saslToken.hasArray() ? 
saslToken.array() : saslToken.toBytes()); + } catch (IOException e) { + IOException sendToClient = e; + Throwable cause = e; + while (cause != null) { + if (cause instanceof InvalidToken) { + sendToClient = (InvalidToken) cause; + break; + } + cause = cause.getCause(); + } + doRawSaslReply(SaslStatus.ERROR, null, sendToClient.getClass().getName(), + sendToClient.getLocalizedMessage()); + metrics.authenticationFailure(); + String clientIP = this.toString(); + // attempting user could be null + AUDITLOG.warn(AUTH_FAILED_FOR + clientIP + ":" + attemptingUser); + throw e; + } + if (replyToken != null) { + if (LOG.isDebugEnabled()) { + LOG.debug("Will send token of size " + replyToken.length + + " from saslServer."); + } + doRawSaslReply(SaslStatus.SUCCESS, new BytesWritable(replyToken), null, + null); + } + if (saslServer.isComplete()) { + String qop = (String) saslServer.getNegotiatedProperty(Sasl.QOP); + useWrap = qop != null && !"auth".equalsIgnoreCase(qop); + ugi = getAuthorizedUgi(saslServer.getAuthorizationID()); + if (LOG.isDebugEnabled()) { + LOG.debug("SASL server context established. Authenticated client: " + + ugi + ". Negotiated QoP is " + + saslServer.getNegotiatedProperty(Sasl.QOP)); + } + metrics.authenticationSuccess(); + AUDITLOG.info(AUTH_SUCCESSFUL_FOR + ugi); + saslContextEstablished = true; + } + } + } + + private void processUnwrappedData(byte[] inBuf) throws IOException, + InterruptedException { + ReadableByteChannel ch = Channels.newChannel(new ByteArrayInputStream(inBuf)); + // Read all RPCs contained in the inBuf, even partial ones + while (true) { + int count; + if (unwrappedDataLengthBuffer.remaining() > 0) { + count = channelRead(ch, unwrappedDataLengthBuffer); + if (count <= 0 || unwrappedDataLengthBuffer.remaining() > 0) + return; + } + + if (unwrappedData == null) { + unwrappedDataLengthBuffer.flip(); + int unwrappedDataLength = unwrappedDataLengthBuffer.getInt(); + + if (unwrappedDataLength == RpcClient.PING_CALL_ID) { + if (LOG.isDebugEnabled()) + LOG.debug("Received ping message"); + unwrappedDataLengthBuffer.clear(); + continue; // ping message + } + unwrappedData = ByteBuffer.allocate(unwrappedDataLength); + } + + count = channelRead(ch, unwrappedData); + if (count <= 0 || unwrappedData.remaining() > 0) + return; + + if (unwrappedData.remaining() == 0) { + unwrappedDataLengthBuffer.clear(); + unwrappedData.flip(); + processOneRpc(new SingleByteBuff(unwrappedData)); + unwrappedData = null; + } + } + } + + private boolean authorizeConnection() throws IOException { + try { + // If auth method is DIGEST, the token was obtained by the + // real user for the effective user, therefore not required to + // authorize real user. 
doAs is allowed only for simple or kerberos + // authentication + if (ugi != null && ugi.getRealUser() != null + && (authMethod != AuthMethod.DIGEST)) { + ProxyUsers.authorize(ugi, this.getHostAddress(), conf); + } + authorize(ugi, connectionHeader, getHostInetAddress()); + metrics.authorizationSuccess(); + } catch (AuthorizationException ae) { + if (LOG.isDebugEnabled()) { + LOG.debug("Connection authorization failed: " + ae.getMessage(), ae); + } + metrics.authorizationFailure(); + setupResponse(authFailedResponse, authFailedCall, + new AccessDeniedException(ae), ae.getMessage()); + authFailedCall.sendResponseIfReady(); + return false; + } + return true; + } + + protected synchronized void close() { + disposeSasl(); + channel.close(); + callCleanup = null; + } + + @Override + public boolean isConnectionOpen() { + return channel.isOpen(); + } + } + + /** + * Datastructure that holds all that is necessary for a method invocation and then afterward, carries the + * result. + */ + @InterfaceStability.Evolving + public class Call extends RpcServer.Call { + + Call(int id, final BlockingService service, final MethodDescriptor md, + RequestHeader header, Message param, CellScanner cellScanner, + Connection connection, long size, TraceInfo tinfo, + final InetAddress remoteAddress, int timeout, CallCleanup reqCleanup) { + super(id, service, md, header, param, cellScanner, + connection, size, tinfo, remoteAddress, timeout, reqCleanup); + } + + @Override + public long disconnectSince() { + if (!getConnection().isConnectionOpen()) { + return System.currentTimeMillis() - timestamp; + } else { + return -1L; + } + } + + Connection getConnection() { + return (Connection) this.connection; + } + + /** + * If we have a response, and delay is not set, then respond immediately. Otherwise, do not + * respond to client. This is called by the RPC code in the context of the Handler thread. + */ + @Override + public synchronized void sendResponseIfReady() throws IOException { + getConnection().channel.writeAndFlush(this); + } + + public synchronized void sendResponseIfReady(ChannelFutureListener listener) throws IOException { + getConnection().channel.writeAndFlush(this).addListener(listener); + } + + } + + private class Initializer extends ChannelInitializer<SocketChannel> { + + @Override + protected void initChannel(SocketChannel channel) throws Exception { + ChannelPipeline pipeline = channel.pipeline(); + pipeline.addLast("header", new ConnectionHeaderHandler()); + pipeline.addLast("decoder", new NettyProtocolDecoder()); + pipeline.addLast("encoder", new MessageEncoder()); + } + + } + + public class ConnectionHeaderHandler extends ReplayingDecoder<State> { + // If initial preamble with version and magic has been read or not.
+ private boolean connectionPreambleRead = false; + private Connection connection; + + public ConnectionHeaderHandler() { + super(State.CHECK_PROTOCOL_VERSION); + } + + private void readPreamble(ChannelHandlerContext ctx, ByteBuf input) throws IOException { + if (input.readableBytes() < 6) { + return; + } + connection = new Connection(ctx.channel()); + connection.readPreamble(input); + ((NettyProtocolDecoder) ctx.pipeline().get("decoder")).setConnection(connection); + connectionPreambleRead = true; + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf byteBuf, List out) + throws Exception { + switch (state()) { + case CHECK_PROTOCOL_VERSION: { + readPreamble(ctx, byteBuf); + if (connectionPreambleRead) { + break; + } + checkpoint(State.READ_AUTH_SCHEMES); + } + } + ctx.pipeline().remove(this); + } + + } + + enum State { + CHECK_PROTOCOL_VERSION, READ_AUTH_SCHEMES + } + + class NettyProtocolDecoder extends ChannelInboundHandlerAdapter { + + private Connection connection; + ByteBuf cumulation; + + void setConnection(Connection connection) { + this.connection = connection; + } + + /** + * Returns the actual number of readable bytes in the internal cumulative + * buffer of this decoder. You usually do not need to rely on this value + * to write a decoder. Use it only when you must use it at your own risk. + * This method is a shortcut to {@link #internalBuffer() internalBuffer().readableBytes()}. + */ + protected int actualReadableBytes() { + return internalBuffer().readableBytes(); + } + + /** + * Returns the internal cumulative buffer of this decoder. You usually + * do not need to access the internal buffer directly to write a decoder. + * Use it only when you must use it at your own risk. + */ + protected ByteBuf internalBuffer() { + if (cumulation != null) { + return cumulation; + } else { + return Unpooled.EMPTY_BUFFER; + } + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + allChannels.add(ctx.channel()); + super.channelActive(ctx); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable e) { + LOG.warn("Unexpected exception from downstream.", e); + allChannels.remove(ctx.channel()); + ctx.channel().close(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + RecyclableArrayList out = RecyclableArrayList.newInstance(); + try { + if (msg instanceof ByteBuf) { + ByteBuf data = (ByteBuf) msg; + metrics.receivedBytes(data.readableBytes()); + if (cumulation == null) { + cumulation = data; + try { + callDecode(ctx, cumulation, out); + } finally { + if (cumulation != null && !cumulation.isReadable()) { + cumulation.release(); + cumulation = null; + } + } + } else { + try { + if (cumulation.writerIndex() > cumulation.maxCapacity() - data.readableBytes()) { + ByteBuf oldCumulation = cumulation; + cumulation = ctx.alloc().buffer(oldCumulation.readableBytes() + data.readableBytes()); + cumulation.writeBytes(oldCumulation); + oldCumulation.release(); + } + cumulation.writeBytes(data); + callDecode(ctx, cumulation, out); + } finally { + if (cumulation != null) { + if (!cumulation.isReadable()) { + cumulation.release(); + cumulation = null; + } else { + cumulation.discardSomeReadBytes(); + } + } + data.release(); + } + } + } else { + out.add(msg); + } + } catch (DecoderException e) { + throw e; + } catch (Throwable t) { + throw new DecoderException(t); + } finally { + if (!out.isEmpty()) { + List results = new ArrayList(); + for (Object result : out) { + 
results.add(result); + } + ctx.fireChannelRead(results); + } + out.recycle(); + } + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + allChannels.remove(ctx.channel()); + RecyclableArrayList out = RecyclableArrayList.newInstance(); + try { + if (cumulation != null) { + callDecode(ctx, cumulation, out); + decodeLast(ctx, cumulation, out); + } else { + decodeLast(ctx, Unpooled.EMPTY_BUFFER, out); + } + } catch (DecoderException e) { + throw e; + } catch (Exception e) { + throw new DecoderException(e); + } finally { + if (cumulation != null) { + cumulation.release(); + cumulation = null; + } + + for (int i = 0; i < out.size(); i++) { + ctx.fireChannelRead(out.get(i)); + } + ctx.fireChannelInactive(); + } + } + + @Override + public final void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + ByteBuf buf = internalBuffer(); + int readable = buf.readableBytes(); + if (buf.isReadable()) { + ByteBuf bytes = buf.readBytes(readable); + buf.release(); + ctx.fireChannelRead(bytes); + } + cumulation = null; + ctx.fireChannelReadComplete(); + } + + /** + * Called once data should be decoded from the given {@link ByteBuf}. This method will call + * {@link #decode(ChannelHandlerContext, ByteBuf, List)} as long as decoding should take place. + * + * @param ctx the {@link ChannelHandlerContext} which this {@link ByteToMessageDecoder} belongs to + * @param in the {@link ByteBuf} from which to read data + * @param out the {@link List} to which decoded messages should be added + */ + protected void callDecode(ChannelHandlerContext ctx, ByteBuf in, List out) { + try { + while (in.isReadable()) { + int outSize = out.size(); + int oldInputLength = in.readableBytes(); + decode(ctx, in, out); + + // Check if this handler was removed before try to continue the loop. + // If it was removed it is not safe to continue to operate on the buffer + // + // See https://github.com/netty/netty/issues/1664 + if (ctx.isRemoved()) { + break; + } + + if (outSize == out.size()) { + if (oldInputLength == in.readableBytes()) { + break; + } else { + continue; + } + } + + if (oldInputLength == in.readableBytes()) { + throw new DecoderException(StringUtil.simpleClassName(getClass()) + + ".decode() did not read anything but decoded a message."); + } + } + } catch (DecoderException e) { + throw e; + } catch (Throwable cause) { + throw new DecoderException(cause); + } + } + + protected void decode(ChannelHandlerContext ctx, ByteBuf buf, + List out) throws Exception { + ByteBuffer data = getData(buf); + if (data != null) { + connection.process(data); + } + } + + private ByteBuffer getData(ByteBuf buf) throws Exception { + // Make sure if the length field was received. + if (buf.readableBytes() < 4) { + // The length field was not received yet - return null. + // This method will be invoked again when more packets are + // received and appended to the buffer. + return null; + } + // The length field is in the buffer. + + // Mark the current buffer position before reading the length field + // because the whole frame might not be in the buffer yet. + // We will reset the buffer position to the marked position if + // there's not enough bytes in the buffer. + buf.markReaderIndex(); + + // Read the length field. + int length = buf.readInt(); + + if (length == RpcClient.PING_CALL_ID) { + if (!connection.useWrap) { // covers the !useSasl too + return null; // ping message + } + } + if (length < 0) { // A data length of zero is legal. 
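+ // A negative length can only come from a corrupt or hostile frame, so fail the decode outright.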
+ throw new IllegalArgumentException("Unexpected data length " + length + + "!! from " + connection.getHostAddress()); + } + if (length > maxRequestSize) { + String warningMsg = "data length is too large: " + length + "!! from " + + connection.getHostAddress() + ":" + connection.getRemotePort(); + LOG.warn(warningMsg); + throw new DoNotRetryIOException(warningMsg); + } + + // Make sure if there's enough bytes in the buffer. + if (buf.readableBytes() < length) { + // The whole bytes were not received yet - return null. + // This method will be invoked again when more packets are + // received and appended to the buffer. + + // Reset to the marked position to read the length field again + // next time. + buf.resetReaderIndex(); + return null; + } + // There's enough bytes in the buffer. Read it. + // TODO eliminate the copy. + ByteBuffer data = ByteBuffer.allocate(length); + buf.readBytes(data); + data.flip(); + return data; + } + + /** + * Is called one last time when the {@link ChannelHandlerContext} goes in-active. Which means the + * {@link #channelInactive(ChannelHandlerContext)} was triggered. + * By default this will just call {@link #decode(ChannelHandlerContext, ByteBuf, List)} but sub-classes may + * override this for some special cleanup operation. + */ + protected void decodeLast(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + decode(ctx, in, out); + } + } + + class MessageEncoder extends ChannelOutboundHandlerAdapter { + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + final Call call = (Call) msg; + ByteBuf response = Unpooled.wrappedBuffer(call.response.getBuffers()); + ctx.write(response, promise).addListener(new CallWriteListener(call)); + } + + } + + class CallWriteListener implements ChannelFutureListener { + private Call call; + + CallWriteListener(Call call) { + this.call = call; + } + + @Override + public void operationComplete(ChannelFuture future) throws Exception { + call.done(); + if (future.isSuccess()) { + metrics.sentBytes(call.response.size()); + } + } + + } + + @Override + public void setSocketSendBufSize(int size) { + } + + @Override + public int getNumOpenConnections() { + // allChannels also contains the server channel, so exclude that from the count. 
+ return allChannels.size() - 1; + } + + @Override + public Pair call(BlockingService service, + MethodDescriptor md, Message param, CellScanner cellScanner, + long receiveTime, MonitoredRPCHandler status) throws IOException { + return call(service, md, param, cellScanner, receiveTime, status, + System.currentTimeMillis(), 0); + } + + @Override + public Pair call(BlockingService service, + MethodDescriptor md, Message param, CellScanner cellScanner, + long receiveTime, MonitoredRPCHandler status, long startTime, int timeout) + throws IOException { + Call fakeCall = new Call(-1, service, md, null, param, cellScanner, null, + -1, null, null, timeout, null); + fakeCall.setReceiveTime(receiveTime); + return call(fakeCall, status); + } + +} diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java index d6a137b..7a67c0d 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java @@ -25,6 +25,9 @@ import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; import java.nio.ByteBuffer; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; import java.security.GeneralSecurityException; import java.util.ArrayList; import java.util.HashMap; @@ -34,6 +37,7 @@ import java.util.Properties; import java.util.concurrent.atomic.LongAdder; import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; import javax.security.sasl.SaslServer; import org.apache.commons.crypto.cipher.CryptoCipherFactory; @@ -459,7 +463,7 @@ public abstract class RpcServer implements RpcServerInterface, } } - private void setExceptionResponse(Throwable t, String errorMsg, + protected void setExceptionResponse(Throwable t, String errorMsg, ResponseHeader.Builder headerBuilder) { ExceptionResponse.Builder exceptionBuilder = ExceptionResponse.newBuilder(); exceptionBuilder.setExceptionClassName(t.getClass().getName()); @@ -477,7 +481,7 @@ public abstract class RpcServer implements RpcServerInterface, headerBuilder.setException(exceptionBuilder.build()); } - private ByteBuffer createHeaderAndMessageBytes(Message result, Message header, + protected ByteBuffer createHeaderAndMessageBytes(Message result, Message header, int cellBlockSize, List cellBlock) throws IOException { // Organize the response as a set of bytebuffers rather than collect it all together inside // one big byte array; save on allocations. @@ -550,7 +554,7 @@ public abstract class RpcServer implements RpcServerInterface, return pbBuf; } - private BufferChain wrapWithSasl(BufferChain bc) + protected BufferChain wrapWithSasl(BufferChain bc) throws IOException { if (!this.connection.useSasl) return bc; // Looks like no way around this; saslserver wants a byte array. I have to make it one. @@ -946,6 +950,17 @@ public abstract class RpcServer implements RpcServerInterface, return ugi; } + protected void disposeSasl() { + if (saslServer != null) { + try { + saslServer.dispose(); + saslServer = null; + } catch (SaslException ignored) { + // Ignored. This is being disposed of anyway. + } + } + } + public abstract boolean isConnectionOpen(); } @@ -1276,6 +1291,77 @@ public abstract class RpcServer implements RpcServerInterface, } /** + * When the read or write buffer size is larger than this limit, i/o will be + * done in chunks of this size. 
Most RPC requests and responses would + * be smaller. + */ + protected static int NIO_BUFFER_LIMIT = 64 * 1024; //should not be more than 64KB. + + /** + * This is a wrapper around {@link java.nio.channels.ReadableByteChannel#read(java.nio.ByteBuffer)}. + * If the amount of data is large, it reads from the channel in smaller chunks. + * This is to avoid the jdk creating many direct buffers as the size of + * the ByteBuffer increases. There should not be any performance degradation. + * + * @param channel readable byte channel to read from + * @param buffer buffer to read into + * @return number of bytes read + * @throws java.io.IOException e + * @see java.nio.channels.ReadableByteChannel#read(java.nio.ByteBuffer) + */ + protected int channelRead(ReadableByteChannel channel, + ByteBuffer buffer) throws IOException { + + int count = (buffer.remaining() <= NIO_BUFFER_LIMIT) ? + channel.read(buffer) : channelIO(channel, null, buffer); + if (count > 0) { + metrics.receivedBytes(count); + } + return count; + } + + /** + * Helper for {@link #channelRead(java.nio.channels.ReadableByteChannel, java.nio.ByteBuffer)} + * and {@link #channelWrite(GatheringByteChannel, BufferChain)}. Only + * one of readCh or writeCh should be non-null. + * + * @param readCh read channel + * @param writeCh write channel + * @param buf buffer to read or write into/out of + * @return bytes read or written + * @throws java.io.IOException e + * @see #channelRead(java.nio.channels.ReadableByteChannel, java.nio.ByteBuffer) + * @see #channelWrite(GatheringByteChannel, BufferChain) + */ + protected static int channelIO(ReadableByteChannel readCh, + WritableByteChannel writeCh, + ByteBuffer buf) throws IOException { + + int originalLimit = buf.limit(); + int initialRemaining = buf.remaining(); + int ret = 0; + + while (buf.remaining() > 0) { + try { + int ioSize = Math.min(buf.remaining(), NIO_BUFFER_LIMIT); + buf.limit(buf.position() + ioSize); + + ret = (readCh == null) ? writeCh.write(buf) : readCh.read(buf); + + if (ret < ioSize) { + break; + } + + } finally { + buf.limit(originalLimit); + } + } + + int nBytes = initialRemaining - buf.remaining(); + return (nBytes > 0) ? nBytes : ret; + } + + /** * This is extracted to a static method for better unit testing. We try to get buffer(s) from pool * as much as possible. * diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleRpcServer.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleRpcServer.java index 01d45cd..592b888 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleRpcServer.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleRpcServer.java @@ -40,7 +40,6 @@ import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; -import java.nio.channels.WritableByteChannel; import java.security.PrivilegedExceptionAction; import java.util.ArrayList; import java.util.Arrays; @@ -178,6 +177,7 @@ public class SimpleRpcServer extends RpcServer { */ @edu.umd.cs.findbugs.annotations.SuppressWarnings(value="IS2_INCONSISTENT_SYNC", justification="Presume the lock on processing request held by caller is protection enough") + @Override void done() { super.done(); this.getConnection().decRpcCount(); // Say that we're done with this call.
@@ -192,6 +192,7 @@ public class SimpleRpcServer extends RpcServer { } } + @Override public synchronized void sendResponseIfReady() throws IOException { // set param null to reduce memory pressure this.param = null; @@ -999,17 +1000,6 @@ public class SimpleRpcServer extends RpcServer { } } - private void disposeSasl() { - if (saslServer != null) { - try { - saslServer.dispose(); - saslServer = null; - } catch (SaslException ignored) { - // Ignored. This is being disposed of anyway. - } - } - } - private int readPreamble() throws IOException { int count; // Check for 'HBas' magic. @@ -1702,12 +1692,14 @@ public class SimpleRpcServer extends RpcServer { return listener.getAddress(); } + @Override public Pair call(BlockingService service, MethodDescriptor md, Message param, CellScanner cellScanner, long receiveTime, MonitoredRPCHandler status) throws IOException { return call(service, md, param, cellScanner, receiveTime, status, System.currentTimeMillis(),0); } + @Override public Pair call(BlockingService service, MethodDescriptor md, Message param, CellScanner cellScanner, long receiveTime, MonitoredRPCHandler status, long startTime, int timeout) @@ -1719,13 +1711,6 @@ public class SimpleRpcServer extends RpcServer { } /** - * When the read or write buffer size is larger than this limit, i/o will be - * done in chunks of this size. Most RPC requests and responses would be - * be smaller. - */ - private static int NIO_BUFFER_LIMIT = 64 * 1024; //should not be more than 64KB. - - /** * This is a wrapper around {@link java.nio.channels.WritableByteChannel#write(java.nio.ByteBuffer)}. * If the amount of data is large, it writes to channel in smaller chunks. * This is to avoid jdk from creating many direct buffers as the size of @@ -1747,70 +1732,6 @@ public class SimpleRpcServer extends RpcServer { } /** - * This is a wrapper around {@link java.nio.channels.ReadableByteChannel#read(java.nio.ByteBuffer)}. - * If the amount of data is large, it writes to channel in smaller chunks. - * This is to avoid jdk from creating many direct buffers as the size of - * ByteBuffer increases. There should not be any performance degredation. - * - * @param channel writable byte channel to write on - * @param buffer buffer to write - * @return number of bytes written - * @throws java.io.IOException e - * @see java.nio.channels.ReadableByteChannel#read(java.nio.ByteBuffer) - */ - protected int channelRead(ReadableByteChannel channel, - ByteBuffer buffer) throws IOException { - - int count = (buffer.remaining() <= NIO_BUFFER_LIMIT) ? - channel.read(buffer) : channelIO(channel, null, buffer); - if (count > 0) { - metrics.receivedBytes(count); - } - return count; - } - - /** - * Helper for {@link #channelRead(java.nio.channels.ReadableByteChannel, java.nio.ByteBuffer)} - * and {@link #channelWrite(GatheringByteChannel, BufferChain)}. Only - * one of readCh or writeCh should be non-null. 
- * - * @param readCh read channel - * @param writeCh write channel - * @param buf buffer to read or write into/out of - * @return bytes written - * @throws java.io.IOException e - * @see #channelRead(java.nio.channels.ReadableByteChannel, java.nio.ByteBuffer) - * @see #channelWrite(GatheringByteChannel, BufferChain) - */ - protected static int channelIO(ReadableByteChannel readCh, - WritableByteChannel writeCh, - ByteBuffer buf) throws IOException { - - int originalLimit = buf.limit(); - int initialRemaining = buf.remaining(); - int ret = 0; - - while (buf.remaining() > 0) { - try { - int ioSize = Math.min(buf.remaining(), NIO_BUFFER_LIMIT); - buf.limit(buf.position() + ioSize); - - ret = (readCh == null) ? writeCh.write(buf) : readCh.read(buf); - - if (ret < ioSize) { - break; - } - - } finally { - buf.limit(originalLimit); - } - } - - int nBytes = initialRemaining - buf.remaining(); - return (nBytes > 0) ? nBytes : ret; - } - - /** * A convenience method to bind to a given address and report * better exceptions if the address is not a valid host. * @param socket the socket to bind diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestProtoBufRpc.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestProtoBufRpc.java index b039003..f21359c 100644 --- a/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestProtoBufRpc.java +++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestProtoBufRpc.java @@ -24,6 +24,8 @@ import static org.junit.Assert.fail; import java.io.IOException; import java.net.InetSocketAddress; +import java.util.Arrays; +import java.util.Collection; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hbase.HBaseConfiguration; @@ -40,6 +42,10 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; import com.google.common.collect.Lists; @@ -48,6 +54,7 @@ import com.google.common.collect.Lists; * of types in src/test/protobuf/test.proto and protobuf service definition from * src/test/protobuf/test_rpc_service.proto */ +@RunWith(Parameterized.class) @Category({ RPCTests.class, MediumTests.class }) public class TestProtoBufRpc { public final static String ADDRESS = "localhost"; @@ -56,9 +63,20 @@ public class TestProtoBufRpc { private Configuration conf; private RpcServerInterface server; + @Parameters(name = "{index}: rpcServerImpl={0}") + public static Collection parameters() { + return Arrays.asList(new Object[] { SimpleRpcServer.class.getName() }, + new Object[] { NettyRpcServer.class.getName() }); + } + + @Parameter(0) + public String rpcServerImpl; + @Before public void setUp() throws IOException { // Setup server for both protocols this.conf = HBaseConfiguration.create(); + this.conf.set(RpcServerFactory.CUSTOM_RPC_SERVER_IMPL_CONF_KEY, + rpcServerImpl); Logger log = Logger.getLogger("org.apache.hadoop.ipc.HBaseServer"); log.setLevel(Level.DEBUG); log = Logger.getLogger("org.apache.hadoop.ipc.HBaseServer.trace"); diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestRpcHandlerException.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestRpcHandlerException.java index 449899f..c12331f 100644 --- a/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestRpcHandlerException.java +++ 
b/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestRpcHandlerException.java @@ -23,11 +23,14 @@ import static org.mockito.Mockito.mock; import java.io.IOException; import java.net.InetSocketAddress; +import java.util.Arrays; +import java.util.Collection; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hbase.Abortable; import org.apache.hadoop.hbase.HBaseConfiguration; import org.apache.hadoop.hbase.ipc.RpcServer.BlockingServiceAndInterface; +import org.apache.hadoop.hbase.shaded.com.google.protobuf.BlockingService; import org.apache.hadoop.hbase.shaded.ipc.protobuf.generated.TestProtos.EchoRequestProto; import org.apache.hadoop.hbase.shaded.ipc.protobuf.generated.TestRpcServiceProtos.TestProtobufRpcProto.BlockingInterface; import org.apache.hadoop.hbase.testclassification.RPCTests; @@ -35,11 +38,14 @@ import org.apache.hadoop.hbase.testclassification.SmallTests; import org.junit.Ignore; import org.junit.Test; import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; import com.google.common.collect.Lists; -import org.apache.hadoop.hbase.shaded.com.google.protobuf.BlockingService; - +@RunWith(Parameterized.class) @Category({ RPCTests.class, SmallTests.class }) public class TestRpcHandlerException { @@ -64,6 +70,15 @@ public class TestRpcHandlerException { } } + @Parameters(name = "{index}: rpcServerImpl={0}") + public static Collection parameters() { + return Arrays.asList(new Object[] { SimpleRpcServer.class.getName() }, + new Object[] { NettyRpcServer.class.getName() }); + } + + @Parameter(0) + public String rpcServerImpl; + /* * This is a unit test to make sure to abort region server when the number of Rpc handler thread * caught errors exceeds the threshold. Client will hang when RS aborts. 
diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/security/TestSecureIPC.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/security/TestSecureIPC.java
index c848250..85a14f2 100644
--- a/hbase-server/src/test/java/org/apache/hadoop/hbase/security/TestSecureIPC.java
+++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/security/TestSecureIPC.java
@@ -35,6 +35,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.List;
 
 import javax.security.sasl.SaslException;
 
@@ -46,11 +47,14 @@ import org.apache.hadoop.hbase.HConstants;
 import org.apache.hadoop.hbase.ipc.BlockingRpcClient;
 import org.apache.hadoop.hbase.ipc.FifoRpcScheduler;
 import org.apache.hadoop.hbase.ipc.NettyRpcClient;
+import org.apache.hadoop.hbase.ipc.NettyRpcServer;
 import org.apache.hadoop.hbase.ipc.RpcClient;
 import org.apache.hadoop.hbase.ipc.RpcClientFactory;
 import org.apache.hadoop.hbase.ipc.RpcServer;
-import org.apache.hadoop.hbase.ipc.RpcServerInterface;
 import org.apache.hadoop.hbase.ipc.RpcServerFactory;
+import org.apache.hadoop.hbase.ipc.RpcServerInterface;
+import org.apache.hadoop.hbase.ipc.SimpleRpcServer;
+import org.apache.hadoop.hbase.shaded.com.google.protobuf.BlockingService;
 import org.apache.hadoop.hbase.shaded.ipc.protobuf.generated.TestProtos;
 import org.apache.hadoop.hbase.shaded.ipc.protobuf.generated.TestRpcServiceProtos.TestProtobufRpcProto.BlockingInterface;
 import org.apache.hadoop.hbase.testclassification.SecurityTests;
@@ -72,7 +76,6 @@ import org.junit.runners.Parameterized.Parameters;
 import org.mockito.Mockito;
 
 import com.google.common.collect.Lists;
-import org.apache.hadoop.hbase.shaded.com.google.protobuf.BlockingService;
 
 @RunWith(Parameterized.class)
 @Category({ SecurityTests.class, SmallTests.class })
@@ -96,15 +99,27 @@ public class TestSecureIPC {
   @Rule
   public ExpectedException exception = ExpectedException.none();
 
-  @Parameters(name = "{index}: rpcClientImpl={0}")
+  @Parameters(name = "{index}: rpcClientImpl={0}, rpcServerImpl={1}")
   public static Collection<Object[]> parameters() {
-    return Arrays.asList(new Object[]{BlockingRpcClient.class.getName()},
-        new Object[]{NettyRpcClient.class.getName()});
+    List<Object[]> params = new ArrayList<>();
+    List<String> rpcClientImpls = Arrays.asList(
+        BlockingRpcClient.class.getName(), NettyRpcClient.class.getName());
+    List<String> rpcServerImpls = Arrays.asList(
+        SimpleRpcServer.class.getName(), NettyRpcServer.class.getName());
+    for (String rpcClientImpl : rpcClientImpls) {
+      for (String rpcServerImpl : rpcServerImpls) {
+        params.add(new Object[] { rpcClientImpl, rpcServerImpl });
+      }
+    }
+    return params;
   }
 
-  @Parameter
+  @Parameter(0)
   public String rpcClientImpl;
 
+  @Parameter(1)
+  public String rpcServerImpl;
+
   @BeforeClass
   public static void setUp() throws Exception {
     KDC = TEST_UTIL.setupMiniKdc(KEYTAB_FILE);
@@ -129,6 +144,8 @@ public class TestSecureIPC {
     clientConf = getSecuredConfiguration();
     clientConf.set(RpcClientFactory.CUSTOM_RPC_CLIENT_IMPL_CONF_KEY, rpcClientImpl);
     serverConf = getSecuredConfiguration();
+    serverConf.set(RpcServerFactory.CUSTOM_RPC_SERVER_IMPL_CONF_KEY,
+      rpcServerImpl);
   }
 
   @Test
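TestSecureIPC's parameters() now returns the client x server cross product rather than a flat client list, so each secured test runs four times (2 clients x 2 servers). A standalone sketch of that enumeration, with the implementation names reduced to plain strings for illustration:

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.Collection;
    import java.util.List;

    // Sketch of the cross-product construction used in parameters():
    // every client implementation is paired with every server implementation.
    public final class CrossProductParams {
      public static Collection<Object[]> crossProduct(List<String> clients,
          List<String> servers) {
        List<Object[]> params = new ArrayList<>();
        for (String client : clients) {
          for (String server : servers) {
            params.add(new Object[] { client, server });
          }
        }
        return params;
      }

      public static void main(String[] args) {
        Collection<Object[]> params = crossProduct(
            Arrays.asList("BlockingRpcClient", "NettyRpcClient"),
            Arrays.asList("SimpleRpcServer", "NettyRpcServer"));
        for (Object[] p : params) {
          System.out.println(Arrays.toString(p)); // four combinations
        }
      }
    }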
diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/security/token/TestTokenAuthentication.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/security/token/TestTokenAuthentication.java
index 92eaecc..5186235 100644
--- a/hbase-server/src/test/java/org/apache/hadoop/hbase/security/token/TestTokenAuthentication.java
+++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/security/token/TestTokenAuthentication.java
@@ -27,6 +27,8 @@ import java.io.IOException;
 import java.io.InterruptedIOException;
 import java.net.InetSocketAddress;
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
 import java.util.List;
 import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.ExecutorService;
@@ -49,11 +51,13 @@ import org.apache.hadoop.hbase.client.ConnectionFactory;
 import org.apache.hadoop.hbase.client.Table;
 import org.apache.hadoop.hbase.coprocessor.RegionCoprocessorEnvironment;
 import org.apache.hadoop.hbase.ipc.FifoRpcScheduler;
+import org.apache.hadoop.hbase.ipc.NettyRpcServer;
 import org.apache.hadoop.hbase.ipc.RpcServer;
 import org.apache.hadoop.hbase.ipc.RpcServer.BlockingServiceAndInterface;
 import org.apache.hadoop.hbase.ipc.RpcServerFactory;
 import org.apache.hadoop.hbase.ipc.RpcServerInterface;
 import org.apache.hadoop.hbase.ipc.ServerRpcController;
+import org.apache.hadoop.hbase.ipc.SimpleRpcServer;
 import org.apache.hadoop.hbase.protobuf.generated.AuthenticationProtos;
 import org.apache.hadoop.hbase.regionserver.HRegion;
 import org.apache.hadoop.hbase.regionserver.RegionServerServices;
@@ -79,10 +83,14 @@ import org.apache.hadoop.security.authorize.Service;
 import org.apache.hadoop.security.token.SecretManager;
 import org.apache.hadoop.security.token.Token;
 import org.apache.hadoop.security.token.TokenIdentifier;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
+import org.junit.After;
+import org.junit.Before;
 import org.junit.Test;
 import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
 
 import com.google.protobuf.BlockingService;
 import com.google.protobuf.RpcController;
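The hunks that follow rework TestTokenAuthentication from a once-per-class static fixture to per-run instance state. That shift is forced by the runner: @Parameter values are injected into instance fields, which static @BeforeClass/@AfterClass methods cannot see, so setup moves to per-test @Before/@After at the cost of rebuilding the fixture for every run. A minimal sketch of the resulting shape (hypothetical fixture, JUnit 4 assumed):

    import java.util.Arrays;
    import java.util.Collection;

    import org.junit.After;
    import org.junit.Assert;
    import org.junit.Before;
    import org.junit.Test;
    import org.junit.runner.RunWith;
    import org.junit.runners.Parameterized;
    import org.junit.runners.Parameterized.Parameter;
    import org.junit.runners.Parameterized.Parameters;

    // Sketch only: the fixture depends on the injected parameter, so it is
    // built per test instance rather than once per class.
    @RunWith(Parameterized.class)
    public class PerRunFixtureExample {

      @Parameters(name = "{index}: impl={0}")
      public static Collection<Object[]> parameters() {
        return Arrays.asList(new Object[] { "implA" }, new Object[] { "implB" });
      }

      @Parameter(0)
      public String impl;

      private StringBuilder fixture; // was static under @BeforeClass

      @Before
      public void setUp() {
        fixture = new StringBuilder("fixture for ").append(impl);
      }

      @After
      public void tearDown() {
        fixture = null; // release per-run resources
      }

      @Test
      public void fixtureSeesTheParameter() {
        Assert.assertTrue(fixture.toString().endsWith(impl));
      }
    }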
@@ -96,6 +104,7 @@ import com.google.protobuf.ServiceException;
 // RpcServer is all about shaded protobuf whereas the Token Service is a CPEP which does non-shaded
 // protobufs. Since hbase-2.0.0, we added convertion from shaded to non-shaded so this test keeps
 // working.
+@RunWith(Parameterized.class)
 @Category({SecurityTests.class, MediumTests.class})
 public class TestTokenAuthentication {
   static {
@@ -115,6 +124,7 @@ public class TestTokenAuthentication {
       AuthenticationProtos.AuthenticationService.BlockingInterface, Runnable, Server {
     private static final Log LOG = LogFactory.getLog(TokenServer.class);
     private Configuration conf;
+    private HBaseTestingUtility TEST_UTIL;
     private RpcServerInterface rpcServer;
     private InetSocketAddress isa;
     private ZooKeeperWatcher zookeeper;
@@ -124,8 +134,10 @@ public class TestTokenAuthentication {
     private boolean stopped = false;
     private long startcode;
 
-    public TokenServer(Configuration conf) throws IOException {
+    public TokenServer(Configuration conf, HBaseTestingUtility TEST_UTIL)
+        throws IOException {
       this.conf = conf;
+      this.TEST_UTIL = TEST_UTIL;
       this.startcode = EnvironmentEdgeManager.currentTime();
       // Server to handle client requests.
       String hostname =
@@ -387,14 +399,23 @@ public class TestTokenAuthentication {
     }
   }
 
-  private static HBaseTestingUtility TEST_UTIL;
-  private static TokenServer server;
-  private static Thread serverThread;
-  private static AuthenticationTokenSecretManager secretManager;
-  private static ClusterId clusterId = new ClusterId();
+  @Parameters(name = "{index}: rpcServerImpl={0}")
+  public static Collection<Object[]> parameters() {
+    return Arrays.asList(new Object[] { SimpleRpcServer.class.getName() },
+        new Object[] { NettyRpcServer.class.getName() });
+  }
+
+  @Parameter(0)
+  public String rpcServerImpl;
+
+  private HBaseTestingUtility TEST_UTIL;
+  private TokenServer server;
+  private Thread serverThread;
+  private AuthenticationTokenSecretManager secretManager;
+  private ClusterId clusterId = new ClusterId();
 
-  @BeforeClass
-  public static void setupBeforeClass() throws Exception {
+  @Before
+  public void setUp() throws Exception {
     TEST_UTIL = new HBaseTestingUtility();
     TEST_UTIL.startMiniZKCluster();
     // register token type for protocol
@@ -406,7 +427,8 @@ public class TestTokenAuthentication {
     conf.set("hadoop.security.authentication", "kerberos");
     conf.set("hbase.security.authentication", "kerberos");
     conf.setBoolean(HADOOP_SECURITY_AUTHORIZATION, true);
-    server = new TokenServer(conf);
+    conf.set(RpcServerFactory.CUSTOM_RPC_SERVER_IMPL_CONF_KEY, rpcServerImpl);
+    server = new TokenServer(conf, TEST_UTIL);
     serverThread = new Thread(server);
     Threads.setDaemonThreadRunning(serverThread, "TokenServer:"+server.getServerName().toString());
     // wait for startup
@@ -428,8 +450,8 @@ public class TestTokenAuthentication {
     }
   }
 
-  @AfterClass
-  public static void tearDownAfterClass() throws Exception {
+  @After
+  public void tearDown() throws Exception {
     server.stop("Test complete");
     Threads.shutdown(serverThread);
     TEST_UTIL.shutdownMiniZKCluster();