diff --git src/main/java/org/apache/hadoop/hbase/coprocessor/AggregateImplementation.java src/main/java/org/apache/hadoop/hbase/coprocessor/AggregateImplementation.java index fce5490..ba3414d 100644 --- src/main/java/org/apache/hadoop/hbase/coprocessor/AggregateImplementation.java +++ src/main/java/org/apache/hadoop/hbase/coprocessor/AggregateImplementation.java @@ -28,6 +28,7 @@ import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hbase.KeyValue; import org.apache.hadoop.hbase.client.Scan; import org.apache.hadoop.hbase.filter.FirstKeyOnlyFilter; +import org.apache.hadoop.hbase.ipc.ProtocolSignature; import org.apache.hadoop.hbase.regionserver.InternalScanner; import org.apache.hadoop.hbase.util.Pair; @@ -40,6 +41,16 @@ public class AggregateImplementation extends BaseEndpointCoprocessor implements protected static Log log = LogFactory.getLog(AggregateImplementation.class); @Override + public ProtocolSignature getProtocolSignature( + String protocol, long version, int clientMethodsHashCode) + throws IOException { + if (AggregateProtocol.class.getName().equals(protocol)) { + return new ProtocolSignature(AggregateProtocol.VERSION, null); + } + throw new IOException("Unknown protocol: " + protocol); + } + + @Override public T getMax(ColumnInterpreter ci, Scan scan) throws IOException { T temp; diff --git src/main/java/org/apache/hadoop/hbase/coprocessor/AggregateProtocol.java src/main/java/org/apache/hadoop/hbase/coprocessor/AggregateProtocol.java index 2fa4d6f..f25ba11 100644 --- src/main/java/org/apache/hadoop/hbase/coprocessor/AggregateProtocol.java +++ src/main/java/org/apache/hadoop/hbase/coprocessor/AggregateProtocol.java @@ -39,6 +39,7 @@ import org.apache.hadoop.hbase.util.Pair; * input parameters. */ public interface AggregateProtocol extends CoprocessorProtocol { + public static final long VERSION = 1L; /** * Gives the maximum for a given combination of column qualifier and column diff --git src/main/java/org/apache/hadoop/hbase/coprocessor/BaseEndpointCoprocessor.java src/main/java/org/apache/hadoop/hbase/coprocessor/BaseEndpointCoprocessor.java index 6f88357..3a787fd 100644 --- src/main/java/org/apache/hadoop/hbase/coprocessor/BaseEndpointCoprocessor.java +++ src/main/java/org/apache/hadoop/hbase/coprocessor/BaseEndpointCoprocessor.java @@ -21,6 +21,7 @@ import java.io.IOException; import org.apache.hadoop.hbase.Coprocessor; import org.apache.hadoop.hbase.CoprocessorEnvironment; import org.apache.hadoop.hbase.ipc.CoprocessorProtocol; +import org.apache.hadoop.hbase.ipc.ProtocolSignature; import org.apache.hadoop.hbase.ipc.VersionedProtocol; /** @@ -62,6 +63,13 @@ public abstract class BaseEndpointCoprocessor implements Coprocessor, public void stop(CoprocessorEnvironment env) { } @Override + public ProtocolSignature getProtocolSignature( + String protocol, long version, int clientMethodsHashCode) + throws IOException { + return new ProtocolSignature(VERSION, null); + } + + @Override public long getProtocolVersion(String protocol, long clientVersion) throws IOException { return VERSION; diff --git src/main/java/org/apache/hadoop/hbase/ipc/CoprocessorProtocol.java src/main/java/org/apache/hadoop/hbase/ipc/CoprocessorProtocol.java index 6fcb771..8211f03 100644 --- src/main/java/org/apache/hadoop/hbase/ipc/CoprocessorProtocol.java +++ src/main/java/org/apache/hadoop/hbase/ipc/CoprocessorProtocol.java @@ -19,8 +19,6 @@ */ package org.apache.hadoop.hbase.ipc; -import org.apache.hadoop.hbase.ipc.VersionedProtocol; - /** * All custom RPC protocols to be exported by Coprocessors must extend this interface. * @@ -37,4 +35,5 @@ import org.apache.hadoop.hbase.ipc.VersionedProtocol; *

*/ public interface CoprocessorProtocol extends VersionedProtocol { + public static final long VERSION = 1L; } diff --git src/main/java/org/apache/hadoop/hbase/ipc/HBaseClient.java src/main/java/org/apache/hadoop/hbase/ipc/HBaseClient.java index 1365411..4086829 100644 --- src/main/java/org/apache/hadoop/hbase/ipc/HBaseClient.java +++ src/main/java/org/apache/hadoop/hbase/ipc/HBaseClient.java @@ -566,6 +566,7 @@ public class HBaseClient { // Currently length if present is unused. in.readInt(); } + int state = in.readInt(); // Read the state. Currently unused. if (isError) { //noinspection ThrowableInstanceNeverThrown call.setException(new RemoteException( WritableUtils.readString(in), diff --git src/main/java/org/apache/hadoop/hbase/ipc/HBaseServer.java src/main/java/org/apache/hadoop/hbase/ipc/HBaseServer.java index 4a8918a..9117f12 100644 --- src/main/java/org/apache/hadoop/hbase/ipc/HBaseServer.java +++ src/main/java/org/apache/hadoop/hbase/ipc/HBaseServer.java @@ -21,6 +21,7 @@ package org.apache.hadoop.hbase.ipc; import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; @@ -65,6 +66,7 @@ import org.apache.hadoop.hbase.util.ByteBufferOutputStream; import org.apache.hadoop.hbase.util.Bytes; import org.apache.hadoop.io.Writable; import org.apache.hadoop.io.WritableUtils; +import org.apache.hadoop.ipc.RPC.VersionMismatch; import org.apache.hadoop.hbase.ipc.VersionedProtocol; import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.util.ReflectionUtils; @@ -88,7 +90,7 @@ public abstract class HBaseServer implements RpcServer { * The first four bytes of Hadoop RPC connections */ public static final ByteBuffer HEADER = ByteBuffer.wrap("hrpc".getBytes()); - public static final byte CURRENT_VERSION = 3; + public static final byte CURRENT_VERSION = 4; /** * How many calls/handler are allowed in the queue. @@ -273,8 +275,8 @@ public abstract class HBaseServer implements RpcServer { return param.toString() + " from " + connection.toString(); } - private synchronized void setResponse(Object value, String errorClass, - String error) { + private synchronized void setResponse(Object value, Status status, + String errorClass, String error) { // Avoid overwriting an error value in the response. This can happen if // endDelayThrowing is called by another thread before the actual call // returning. @@ -323,6 +325,7 @@ public abstract class HBaseServer implements RpcServer { // Place holder for length set later below after we // fill the buffer with data. out.writeInt(0xdeadbeef); + out.writeInt(status.state); } catch (IOException e) { errorClass = e.getClass().getName(); error = StringUtils.stringifyException(e); @@ -358,7 +361,7 @@ public abstract class HBaseServer implements RpcServer { this.delayResponse = false; delayedCalls.decrementAndGet(); if (this.delayReturnValue) - this.setResponse(result, null, null); + this.setResponse(result, Status.SUCCESS, null, null); this.responder.doRespond(this); } @@ -381,7 +384,7 @@ public abstract class HBaseServer implements RpcServer { @Override public synchronized void endDelayThrowing(Throwable t) throws IOException { - this.setResponse(null, t.getClass().toString(), + this.setResponse(null, Status.ERROR, t.getClass().toString(), StringUtils.stringifyException(t)); this.delayResponse = false; this.sendResponseIfReady(); @@ -443,8 +446,7 @@ public abstract class HBaseServer implements RpcServer { new ThreadFactoryBuilder().setNameFormat( "IPC Reader %d on port " + port).setDaemon(true).build()); for (int i = 0; i < readThreads; ++i) { - Selector readSelector = Selector.open(); - Reader reader = new Reader(readSelector); + Reader reader = new Reader(); readers[i] = reader; readPool.execute(reader); } @@ -458,40 +460,51 @@ public abstract class HBaseServer implements RpcServer { private class Reader implements Runnable { private volatile boolean adding = false; - private Selector readSelector = null; + private final Selector readSelector; - Reader(Selector readSelector) { - this.readSelector = readSelector; + Reader() throws IOException { + this.readSelector = Selector.open(); } public void run() { - synchronized(this) { - while (running) { - SelectionKey key = null; - try { - readSelector.select(); - while (adding) { - this.wait(1000); - } + LOG.info("Starting " + getName()); + try { + doRunLoop(); + } finally { + try { + readSelector.close(); + } catch (IOException ioe) { + LOG.error("Error closing read selector in " + getName(), ioe); + } + } + } + + private synchronized void doRunLoop() { + while (running) { + SelectionKey key = null; + try { + readSelector.select(); + while (adding) { + this.wait(1000); + } - Iterator iter = readSelector.selectedKeys().iterator(); - while (iter.hasNext()) { - key = iter.next(); - iter.remove(); - if (key.isValid()) { - if (key.isReadable()) { - doRead(key); - } + Iterator iter = readSelector.selectedKeys().iterator(); + while (iter.hasNext()) { + key = iter.next(); + iter.remove(); + if (key.isValid()) { + if (key.isReadable()) { + doRead(key); } - key = null; } - } catch (InterruptedException e) { - if (running) { // unexpected -- log it - LOG.info(getName() + "caught: " + - StringUtils.stringifyException(e)); - } - } catch (IOException ex) { - LOG.error("Error in Reader", ex); + key = null; + } + } catch (InterruptedException e) { + if (running) { // unexpected -- log it + LOG.info(getName() + " unexpectedly interrupted: " + + StringUtils.stringifyException(e)); } + } catch (IOException ex) { + LOG.error("Error in Reader", ex); } } } @@ -730,7 +743,7 @@ public abstract class HBaseServer implements RpcServer { // Sends responses of RPC back to clients. private class Responder extends Thread { - private Selector writeSelector; + private final Selector writeSelector; private int pending; // connections waiting to register final static int PURGE_INTERVAL = 900000; // 15mins @@ -746,6 +759,19 @@ public abstract class HBaseServer implements RpcServer { public void run() { LOG.info(getName() + ": starting"); SERVER.set(HBaseServer.this); + try { + doRunLoop(); + } finally { + LOG.info("Stopping " + this.getName()); + try { + writeSelector.close(); + } catch (IOException ioe) { + LOG.error("Couldn't close write selector in " + this.getName(), ioe); + } + } + } + + private void doRunLoop() { long lastPurgeTime = 0; // last check for old calls. while (running) { @@ -1106,6 +1132,7 @@ public abstract class HBaseServer implements RpcServer { hostAddress + ":" + remotePort + " got version " + version + " expected version " + CURRENT_VERSION); + setupBadVersionResponse(version); return -1; } dataLengthBuffer.clear(); @@ -1144,6 +1171,30 @@ public abstract class HBaseServer implements RpcServer { } } + /** + * Try to set up the response to indicate that the client version + * is incompatible with the server. This can contain special-case + * code to speak enough of past IPC protocols to pass back + * an exception to the caller. + * @param clientVersion the version the caller is using + * @throws IOException + */ + private void setupBadVersionResponse(int clientVersion) throws IOException { + String errMsg = "Server IPC version " + CURRENT_VERSION + + " cannot communicate with client version " + clientVersion; + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + + if (clientVersion >= 3) { + Call fakeCall = new Call(-1, null, this, responder); + // Versions 3 and greater can interpret this exception + // response in the same manner + setupResponse(buffer, fakeCall, Status.FATAL, + null, VersionMismatch.class.getName(), errMsg); + + responder.doRespond(fakeCall); + } + } + /// Reads the connection header following version private void processHeader() throws IOException { DataInputStream in = @@ -1170,9 +1221,22 @@ public abstract class HBaseServer implements RpcServer { if (LOG.isDebugEnabled()) LOG.debug(" got call #" + id + ", " + array.length + " bytes"); - Writable param = ReflectionUtils.newInstance(paramClass, conf); // read param - param.readFields(dis); - + Writable param; + try { + param = ReflectionUtils.newInstance(paramClass, conf);//read param + param.readFields(dis); + } catch (Throwable t) { + LOG.warn("Unable to read call parameters for client " + + getHostAddress(), t); + final Call readParamsFailedCall = new Call(id, null, this, responder); + ByteArrayOutputStream responseBuffer = new ByteArrayOutputStream(); + + setupResponse(responseBuffer, readParamsFailedCall, Status.FATAL, null, + t.getClass().getName(), + "IPC server unable to read call parameters: " + t.getMessage()); + responder.doRespond(readParamsFailedCall); + return; + } Call call = new Call(id, param, this, responder); if (priorityCallQueue != null && getQosLevel(param) > highPriorityLevel) { @@ -1251,7 +1315,9 @@ public abstract class HBaseServer implements RpcServer { // Set the response for undelayed calls and delayed calls with // undelayed responses. if (!call.isDelayed() || !call.isReturnValueDelayed()) { - call.setResponse(value, errorClass, error); + call.setResponse(value, + errorClass == null? Status.SUCCESS: Status.ERROR, + errorClass, error); } call.sendResponseIfReady(); } catch (InterruptedException e) { @@ -1356,6 +1422,41 @@ public abstract class HBaseServer implements RpcServer { responder = new Responder(); } + /** + * Setup response for the IPC Call. + * + * @param response buffer to serialize the response into + * @param call {@link Call} to which we are setting up the response + * @param status {@link Status} of the IPC call + * @param rv return value for the IPC Call, if the call was successful + * @param errorClass error class, if the the call failed + * @param error error message, if the call failed + * @throws IOException + */ + private void setupResponse(ByteArrayOutputStream response, + Call call, Status status, + Writable rv, String errorClass, String error) + throws IOException { + response.reset(); + DataOutputStream out = new DataOutputStream(response); + + if (status == Status.SUCCESS) { + try { + rv.write(out); + call.setResponse(rv, status, null, null); + } catch (Throwable t) { + LOG.warn("Error serializing call response for call " + call, t); + // Call back to same function - this is OK since the + // buffer is reset at the top, and since status is changed + // to ERROR it won't infinite loop. + call.setResponse(null, status.ERROR, t.getClass().getName(), + StringUtils.stringifyException(t)); + } + } else { + call.setResponse(rv, status, errorClass, error); + } + } + protected void closeConnection(Connection connection) { synchronized (connectionList) { if (connectionList.remove(connection)) diff --git src/main/java/org/apache/hadoop/hbase/ipc/Invocation.java src/main/java/org/apache/hadoop/hbase/ipc/Invocation.java index e60f970..f04160f 100644 --- src/main/java/org/apache/hadoop/hbase/ipc/Invocation.java +++ src/main/java/org/apache/hadoop/hbase/ipc/Invocation.java @@ -22,20 +22,25 @@ package org.apache.hadoop.hbase.ipc; import org.apache.hadoop.conf.Configurable; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hbase.io.HbaseObjectWritable; -import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.VersionedWritable; import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; +import java.lang.reflect.Field; import java.lang.reflect.Method; /** A method invocation, including the method name and its parameters.*/ -public class Invocation implements Writable, Configurable { +public class Invocation extends VersionedWritable implements Configurable { protected String methodName; @SuppressWarnings("unchecked") protected Class[] parameterClasses; protected Object[] parameters; protected Configuration conf; + private long clientVersion; + private int clientMethodsHash; + + private static byte RPC_VERSION = 1; public Invocation() {} @@ -43,6 +48,23 @@ public class Invocation implements Writable, Configurable { this.methodName = method.getName(); this.parameterClasses = method.getParameterTypes(); this.parameters = parameters; + if (method.getDeclaringClass().equals(VersionedProtocol.class)) { + //VersionedProtocol is exempted from version check. + clientVersion = 0; + clientMethodsHash = 0; + } else { + try { + Field versionField = method.getDeclaringClass().getField("VERSION"); + versionField.setAccessible(true); + this.clientVersion = versionField.getLong(method.getDeclaringClass()); + } catch (NoSuchFieldException ex) { + throw new RuntimeException("The " + method.getDeclaringClass(), ex); + } catch (IllegalAccessException ex) { + throw new RuntimeException(ex); + } + this.clientMethodsHash = ProtocolSignature.getFingerprint(method + .getDeclaringClass().getMethods()); + } } /** @return The name of the method invoked. */ @@ -55,8 +77,31 @@ public class Invocation implements Writable, Configurable { /** @return The parameter instances. */ public Object[] getParameters() { return parameters; } + long getProtocolVersion() { + return clientVersion; + } + + protected int getClientMethodsHash() { + return clientMethodsHash; + } + + /** + * Returns the rpc version used by the client. + * @return rpcVersion + */ + public long getRpcVersion() { + return RPC_VERSION; + } + public void readFields(DataInput in) throws IOException { + super.readFields(in); + if (getVersion() != RPC_VERSION) { + throw new IOException("Unknown version; expected=" + RPC_VERSION + + ", got=" + getVersion()); + } methodName = in.readUTF(); + clientVersion = in.readLong(); + clientMethodsHash = in.readInt(); parameters = new Object[in.readInt()]; parameterClasses = new Class[parameters.length]; HbaseObjectWritable objectWritable = new HbaseObjectWritable(); @@ -68,7 +113,10 @@ public class Invocation implements Writable, Configurable { } public void write(DataOutput out) throws IOException { + out.writeByte(getVersion()); out.writeUTF(this.methodName); + out.writeLong(clientVersion); + out.writeInt(clientMethodsHash); out.writeInt(parameterClasses.length); for (int i = 0; i < parameterClasses.length; i++) { HbaseObjectWritable.writeObject(out, parameters[i], parameterClasses[i], @@ -87,6 +135,9 @@ public class Invocation implements Writable, Configurable { buffer.append(parameters[i]); } buffer.append(")"); + buffer.append(", rpc version="+RPC_VERSION); + buffer.append(", client version="+clientVersion); + buffer.append(", methodsFingerPrint="+clientMethodsHash); return buffer.toString(); } @@ -98,4 +149,8 @@ public class Invocation implements Writable, Configurable { return this.conf; } + @Override + public byte getVersion() { + return RPC_VERSION; + } } diff --git src/main/java/org/apache/hadoop/hbase/ipc/ProtocolSignature.java src/main/java/org/apache/hadoop/hbase/ipc/ProtocolSignature.java new file mode 100644 index 0000000..f345cee --- /dev/null +++ src/main/java/org/apache/hadoop/hbase/ipc/ProtocolSignature.java @@ -0,0 +1,241 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hbase.ipc; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.HashMap; + +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.WritableFactories; +import org.apache.hadoop.io.WritableFactory; + +public class ProtocolSignature implements Writable { + static { // register a ctor + WritableFactories.setFactory + (ProtocolSignature.class, + new WritableFactory() { + public Writable newInstance() { return new ProtocolSignature(); } + }); + } + + private long version; + private int[] methods = null; // an array of method hash codes + + /** + * default constructor + */ + public ProtocolSignature() { + } + + /** + * Constructor + * + * @param version server version + * @param methodHashcodes hash codes of the methods supported by server + */ + public ProtocolSignature(long version, int[] methodHashcodes) { + this.version = version; + this.methods = methodHashcodes; + } + + public long getVersion() { + return version; + } + + public int[] getMethods() { + return methods; + } + + @Override + public void readFields(DataInput in) throws IOException { + version = in.readLong(); + boolean hasMethods = in.readBoolean(); + if (hasMethods) { + int numMethods = in.readInt(); + methods = new int[numMethods]; + for (int i=0; i type : method.getParameterTypes()) { + hashcode = 31*hashcode ^ type.getName().hashCode(); + } + return hashcode; + } + + /** + * Convert an array of Method into an array of hash codes + * + * @param methods + * @return array of hash codes + */ + private static int[] getFingerprints(Method[] methods) { + if (methods == null) { + return null; + } + int[] hashCodes = new int[methods.length]; + for (int i = 0; i + PROTOCOL_FINGERPRINT_CACHE = + new HashMap(); + + /** + * Return a protocol's signature and finger print from cache + * + * @param protocol a protocol class + * @param serverVersion protocol version + * @return its signature and finger print + */ + private static ProtocolSigFingerprint getSigFingerprint( + Class protocol, long serverVersion) { + String protocolName = protocol.getName(); + synchronized (PROTOCOL_FINGERPRINT_CACHE) { + ProtocolSigFingerprint sig = PROTOCOL_FINGERPRINT_CACHE.get(protocolName); + if (sig == null) { + int[] serverMethodHashcodes = getFingerprints(protocol.getMethods()); + sig = new ProtocolSigFingerprint( + new ProtocolSignature(serverVersion, serverMethodHashcodes), + getFingerprint(serverMethodHashcodes)); + PROTOCOL_FINGERPRINT_CACHE.put(protocolName, sig); + } + return sig; + } + } + + /** + * Get a server protocol's signature + * + * @param clientMethodsHashCode client protocol methods hashcode + * @param serverVersion server protocol version + * @param protocol protocol + * @return the server's protocol signature + */ + static ProtocolSignature getProtocolSignature( + int clientMethodsHashCode, + long serverVersion, + Class protocol) { + // try to get the finger print & signature from the cache + ProtocolSigFingerprint sig = getSigFingerprint(protocol, serverVersion); + + // check if the client side protocol matches the one on the server side + if (clientMethodsHashCode == sig.fingerprint) { + return new ProtocolSignature(serverVersion, null); // null indicates a match + } + + return sig.signature; + } + + /** + * Get a server protocol's signature + * + * @param server server implementation + * @param protocol server protocol + * @param clientVersion client's version + * @param clientMethodsHash client's protocol's hash code + * @return the server protocol's signature + * @throws IOException if any error occurs + */ + @SuppressWarnings("unchecked") + public static ProtocolSignature getProtocolSignature(VersionedProtocol server, + String protocol, + long clientVersion, int clientMethodsHash) throws IOException { + Class inter; + try { + inter = (Class)Class.forName(protocol); + } catch (Exception e) { + throw new IOException(e); + } + long serverVersion = server.getProtocolVersion(protocol, clientVersion); + return ProtocolSignature.getProtocolSignature( + clientMethodsHash, serverVersion, inter); + } +} diff --git src/main/java/org/apache/hadoop/hbase/ipc/Status.java src/main/java/org/apache/hadoop/hbase/ipc/Status.java new file mode 100644 index 0000000..c61282f --- /dev/null +++ src/main/java/org/apache/hadoop/hbase/ipc/Status.java @@ -0,0 +1,32 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.ipc; + +/** + * Status of a Hadoop IPC call. + */ +enum Status { + SUCCESS (0), + ERROR (1), + FATAL (-1); + + int state; + private Status(int state) { + this.state = state; + } +} diff --git src/main/java/org/apache/hadoop/hbase/ipc/VersionedProtocol.java src/main/java/org/apache/hadoop/hbase/ipc/VersionedProtocol.java index fb07374..9568b1b 100644 --- src/main/java/org/apache/hadoop/hbase/ipc/VersionedProtocol.java +++ src/main/java/org/apache/hadoop/hbase/ipc/VersionedProtocol.java @@ -24,12 +24,9 @@ import java.io.IOException; * Superclass of all protocols that use Hadoop RPC. * Subclasses of this interface are also supposed to have * a static final long versionID field. - * - * This has been copied from the Hadoop IPC project so that - * we can run on multiple different versions of Hadoop. */ public interface VersionedProtocol { - + /** * Return protocol version corresponding to protocol interface. * @param protocol The classname of the protocol interface @@ -40,4 +37,18 @@ public interface VersionedProtocol { @Deprecated public long getProtocolVersion(String protocol, long clientVersion) throws IOException; + + /** + * Return protocol version corresponding to protocol interface. + * @param protocol The classname of the protocol interface + * @param clientVersion The version of the protocol that the client speaks + * @param clientMethodsHash the hashcode of client protocol methods + * @return the server protocol signature containing its version and + * a list of its supported methods + * @see ProtocolSignature#getProtocolSignature(VersionedProtocol, String, + * long, int) for a default implementation + */ + public ProtocolSignature getProtocolSignature(String protocol, + long clientVersion, + int clientMethodsHash) throws IOException; } diff --git src/main/java/org/apache/hadoop/hbase/ipc/WritableRpcEngine.java src/main/java/org/apache/hadoop/hbase/ipc/WritableRpcEngine.java index 60a9248..79c4b19 100644 --- src/main/java/org/apache/hadoop/hbase/ipc/WritableRpcEngine.java +++ src/main/java/org/apache/hadoop/hbase/ipc/WritableRpcEngine.java @@ -43,6 +43,7 @@ import org.apache.hadoop.hbase.regionserver.HRegionServer; import org.apache.hadoop.hbase.util.Bytes; import org.apache.hadoop.hbase.util.Objects; import org.apache.hadoop.io.*; +import org.apache.hadoop.ipc.RPC; import org.apache.hadoop.hbase.ipc.VersionedProtocol; import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.authorize.ServiceAuthorizationManager; @@ -56,6 +57,10 @@ class WritableRpcEngine implements RpcEngine { // DEBUG log level does NOT emit RPC-level logging. private static final Log LOG = LogFactory.getLog("org.apache.hadoop.ipc.RPCEngine"); + //writableRpcVersion should be updated if there is a change + //in format of the rpc messages. + public static long writableRpcVersion = 1L; + /* Cache a client using its socket factory as the hash key */ static private class ClientCache { private Map clients = @@ -335,6 +340,30 @@ class WritableRpcEngine implements RpcEngine { call.getParameterClasses()); method.setAccessible(true); + // Verify rpc version + if (call.getRpcVersion() != writableRpcVersion) { + // Client is using a different version of WritableRpc + throw new IOException( + "WritableRpc version mismatch, client side version=" + + call.getRpcVersion() + ", server side version=" + + writableRpcVersion); + } + + //Verify protocol version. + //Bypass the version check for VersionedProtocol + if (!method.getDeclaringClass().equals(VersionedProtocol.class)) { + long clientVersion = call.getProtocolVersion(); + ProtocolSignature serverInfo = ((VersionedProtocol) instance) + .getProtocolSignature(protocol.getCanonicalName(), call + .getProtocolVersion(), call.getClientMethodsHash()); + long serverVersion = serverInfo.getVersion(); + if (serverVersion != clientVersion) { + LOG.warn("Version mismatch: client version=" + clientVersion + + ", server version=" + serverVersion); + throw new RPC.VersionMismatch(protocol.getName(), clientVersion, + serverVersion); + } + } Object impl = null; if (protocol.isAssignableFrom(this.implementation)) { impl = this.instance; diff --git src/main/java/org/apache/hadoop/hbase/master/AssignmentManager.java src/main/java/org/apache/hadoop/hbase/master/AssignmentManager.java index 8de2314..a55686c 100644 --- src/main/java/org/apache/hadoop/hbase/master/AssignmentManager.java +++ src/main/java/org/apache/hadoop/hbase/master/AssignmentManager.java @@ -61,7 +61,6 @@ import org.apache.hadoop.hbase.executor.EventHandler.EventType; import org.apache.hadoop.hbase.executor.ExecutorService; import org.apache.hadoop.hbase.executor.RegionTransitionData; import org.apache.hadoop.hbase.ipc.ServerNotRunningYetException; -import org.apache.hadoop.hbase.master.AssignmentManager.RegionState; import org.apache.hadoop.hbase.master.handler.ClosedRegionHandler; import org.apache.hadoop.hbase.master.handler.DisableTableHandler; import org.apache.hadoop.hbase.master.handler.EnableTableHandler; diff --git src/main/java/org/apache/hadoop/hbase/master/HMaster.java src/main/java/org/apache/hadoop/hbase/master/HMaster.java index 0d0e4c5..74ee0be 100644 --- src/main/java/org/apache/hadoop/hbase/master/HMaster.java +++ src/main/java/org/apache/hadoop/hbase/master/HMaster.java @@ -61,6 +61,7 @@ import org.apache.hadoop.hbase.ipc.HBaseRPC; import org.apache.hadoop.hbase.ipc.HBaseServer; import org.apache.hadoop.hbase.ipc.HMasterInterface; import org.apache.hadoop.hbase.ipc.HMasterRegionInterface; +import org.apache.hadoop.hbase.ipc.ProtocolSignature; import org.apache.hadoop.hbase.ipc.RpcServer; import org.apache.hadoop.hbase.master.handler.CreateTableHandler; import org.apache.hadoop.hbase.master.handler.DeleteTableHandler; @@ -561,6 +562,18 @@ implements HMasterInterface, HMasterRegionInterface, MasterServices, Server { return assigned; } + @Override + public ProtocolSignature getProtocolSignature( + String protocol, long version, int clientMethodsHashCode) + throws IOException { + if (HMasterInterface.class.getName().equals(protocol)) { + return new ProtocolSignature(HMasterInterface.VERSION, null); + } else if (HMasterRegionInterface.class.getName().equals(protocol)) { + return new ProtocolSignature(HMasterRegionInterface.VERSION, null); + } + throw new IOException("Unknown protocol: " + protocol); + } + public long getProtocolVersion(String protocol, long clientVersion) { if (HMasterInterface.class.getName().equals(protocol)) { return HMasterInterface.VERSION; diff --git src/main/java/org/apache/hadoop/hbase/regionserver/HRegionServer.java src/main/java/org/apache/hadoop/hbase/regionserver/HRegionServer.java index 12bd33e..8698be7 100644 --- src/main/java/org/apache/hadoop/hbase/regionserver/HRegionServer.java +++ src/main/java/org/apache/hadoop/hbase/regionserver/HRegionServer.java @@ -112,6 +112,7 @@ import org.apache.hadoop.hbase.ipc.HBaseRpcMetrics; import org.apache.hadoop.hbase.ipc.HMasterRegionInterface; import org.apache.hadoop.hbase.ipc.HRegionInterface; import org.apache.hadoop.hbase.ipc.Invocation; +import org.apache.hadoop.hbase.ipc.ProtocolSignature; import org.apache.hadoop.hbase.ipc.RpcServer; import org.apache.hadoop.hbase.ipc.ServerNotRunningYetException; import org.apache.hadoop.hbase.regionserver.Leases.LeaseStillHeldException; @@ -2871,6 +2872,17 @@ public class HRegionServer implements HRegionInterface, HBaseRPCErrorHandler, @Override @QosPriority(priority=HIGH_QOS) + public ProtocolSignature getProtocolSignature( + String protocol, long version, int clientMethodsHashCode) + throws IOException { + if (protocol.equals(HRegionInterface.class.getName())) { + return new ProtocolSignature(HRegionInterface.VERSION, null); + } + throw new IOException("Unknown protocol: " + protocol); + } + + @Override + @QosPriority(priority=HIGH_QOS) public long getProtocolVersion(final String protocol, final long clientVersion) throws IOException { if (protocol.equals(HRegionInterface.class.getName())) { diff --git src/test/java/org/apache/hadoop/hbase/ipc/TestDelayedRpc.java src/test/java/org/apache/hadoop/hbase/ipc/TestDelayedRpc.java index 888f428..4902ed5 100644 --- src/test/java/org/apache/hadoop/hbase/ipc/TestDelayedRpc.java +++ src/test/java/org/apache/hadoop/hbase/ipc/TestDelayedRpc.java @@ -25,6 +25,7 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import java.io.IOException; +import java.lang.reflect.Method; import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.List; @@ -85,8 +86,8 @@ public class TestDelayedRpc { th2.join(); th3.join(); - assertEquals(results.get(0).intValue(), UNDELAYED); - assertEquals(results.get(1).intValue(), UNDELAYED); + assertEquals(UNDELAYED, results.get(0).intValue()); + assertEquals(UNDELAYED, results.get(1).intValue()); assertEquals(results.get(2).intValue(), delayReturnValue ? DELAYED : 0xDEADBEEF); } @@ -157,6 +158,7 @@ public class TestDelayedRpc { } public interface TestRpc extends VersionedProtocol { + public static final long VERSION = 1L; int test(boolean delay); } @@ -202,6 +204,17 @@ public class TestDelayedRpc { public long getProtocolVersion(String arg0, long arg1) throws IOException { return 0; } + + @Override + public ProtocolSignature getProtocolSignature(String protocol, + long clientVersion, int clientMethodsHash) throws IOException { + Method [] methods = this.getClass().getMethods(); + int [] hashes = new int [methods.length]; + for (int i = 0; i < methods.length; i++) { + hashes[i] = methods[i].hashCode(); + } + return new ProtocolSignature(clientVersion, hashes); + } } private static class TestThread extends Thread { @@ -283,5 +296,11 @@ public class TestDelayedRpc { public long getProtocolVersion(String arg0, long arg1) throws IOException { return 0; } + + @Override + public ProtocolSignature getProtocolSignature(String protocol, + long clientVersion, int clientMethodsHash) throws IOException { + return new ProtocolSignature(clientVersion, new int [] {}); + } } } diff --git src/test/java/org/apache/hadoop/hbase/regionserver/TestServerCustomProtocol.java src/test/java/org/apache/hadoop/hbase/regionserver/TestServerCustomProtocol.java index e5b6a78..61058c9 100644 --- src/test/java/org/apache/hadoop/hbase/regionserver/TestServerCustomProtocol.java +++ src/test/java/org/apache/hadoop/hbase/regionserver/TestServerCustomProtocol.java @@ -39,6 +39,9 @@ import org.apache.hadoop.hbase.client.Row; import org.apache.hadoop.hbase.client.coprocessor.Batch; import org.apache.hadoop.hbase.coprocessor.CoprocessorHost; import org.apache.hadoop.hbase.ipc.CoprocessorProtocol; +import org.apache.hadoop.hbase.ipc.HMasterInterface; +import org.apache.hadoop.hbase.ipc.HMasterRegionInterface; +import org.apache.hadoop.hbase.ipc.ProtocolSignature; import org.apache.hadoop.hbase.util.Bytes; import org.apache.hadoop.hbase.util.JVMClusterUtil; import org.apache.hadoop.hbase.ipc.VersionedProtocol; @@ -59,8 +62,7 @@ public class TestServerCustomProtocol { /* Test protocol implementation */ public static class PingHandler implements Coprocessor, PingProtocol, VersionedProtocol { - static int VERSION = 1; - + static long VERSION = 1; private int counter = 0; @Override public String ping() { @@ -90,6 +92,13 @@ public class TestServerCustomProtocol { } @Override + public ProtocolSignature getProtocolSignature( + String protocol, long version, int clientMethodsHashCode) + throws IOException { + return new ProtocolSignature(VERSION, null); + } + + @Override public long getProtocolVersion(String s, long l) throws IOException { return VERSION; }