diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/BlockingRpcConnection.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/BlockingRpcConnection.java index c8b366d..7635070 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/BlockingRpcConnection.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/BlockingRpcConnection.java @@ -66,6 +66,7 @@ import org.apache.hadoop.hbase.protobuf.generated.RPCProtos.ExceptionResponse; import org.apache.hadoop.hbase.protobuf.generated.RPCProtos.RequestHeader; import org.apache.hadoop.hbase.protobuf.generated.RPCProtos.ResponseHeader; import org.apache.hadoop.hbase.security.HBaseSaslRpcClient; +import org.apache.hadoop.hbase.security.SaslUtil; import org.apache.hadoop.hbase.security.SaslUtil.QualityOfProtection; import org.apache.hadoop.hbase.util.EnvironmentEdgeManager; import org.apache.hadoop.hbase.util.ExceptionUtil; @@ -350,7 +351,7 @@ class BlockingRpcConnection extends RpcConnection implements Runnable { saslRpcClient = new HBaseSaslRpcClient(authMethod, token, serverPrincipal, this.rpcClient.fallbackAllowed, this.rpcClient.conf.get("hbase.rpc.protection", QualityOfProtection.AUTHENTICATION.name().toLowerCase(Locale.ROOT))); - return saslRpcClient.saslConnect(in2, out2); + return saslRpcClient.saslConnect(this.rpcClient.conf, in2, out2); } /** diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/HBaseSaslRpcClient.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/HBaseSaslRpcClient.java index 3f43f7f..be5bab5 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/HBaseSaslRpcClient.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/HBaseSaslRpcClient.java @@ -31,11 +31,14 @@ import javax.security.sasl.SaslException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.crypto.CipherOption; 
+import org.apache.hadoop.crypto.CipherSuite; import org.apache.hadoop.hbase.classification.InterfaceAudience; import org.apache.hadoop.io.WritableUtils; import org.apache.hadoop.ipc.RemoteException; -import org.apache.hadoop.security.SaslInputStream; -import org.apache.hadoop.security.SaslOutputStream; +import org.apache.hadoop.hbase.security.SaslInputStream; +import org.apache.hadoop.hbase.security.SaslOutputStream; import org.apache.hadoop.security.token.Token; import org.apache.hadoop.security.token.TokenIdentifier; @@ -48,6 +51,9 @@ public class HBaseSaslRpcClient extends AbstractHBaseSaslRpcClient { private static final Log LOG = LogFactory.getLog(HBaseSaslRpcClient.class); + private boolean useNegotiatedCipher; + private SaslCryptoCodec saslCodec; + public HBaseSaslRpcClient(AuthMethod method, Token token, String serverPrincipal, boolean fallbackAllowed) throws IOException { super(method, token, serverPrincipal, fallbackAllowed); @@ -73,13 +79,23 @@ public class HBaseSaslRpcClient extends AbstractHBaseSaslRpcClient { * @return true if connection is set up, or false if needs to switch to simple Auth. 
* @throws IOException */ - public boolean saslConnect(InputStream inS, OutputStream outS) throws IOException { + public boolean saslConnect(Configuration conf, InputStream inS, OutputStream outS) throws IOException { + String cipherSuites = conf.get(SaslUtil.HBASE_RPC_SECURITY_CRYPTO_CIPHER_SUITES); + useNegotiatedCipher = SaslUtil.requestedQopContainsPrivacy(saslProps) && + cipherSuites != null && !cipherSuites.isEmpty(); + DataInputStream inStream = new DataInputStream(new BufferedInputStream(inS)); DataOutputStream outStream = new DataOutputStream(new BufferedOutputStream(outS)); try { byte[] saslToken = getInitialResponse(); if (saslToken != null) { + if (useNegotiatedCipher) { + if (LOG.isDebugEnabled()) + LOG.debug("Will send client ciphers: " + cipherSuites); + outStream.writeInt(SaslUtil.USE_NEGOTIATED_CIPHER); + WritableUtils.writeString(outStream, cipherSuites); + } outStream.writeInt(saslToken.length); outStream.write(saslToken, 0, saslToken.length); outStream.flush(); @@ -109,6 +125,7 @@ public class HBaseSaslRpcClient extends AbstractHBaseSaslRpcClient { inStream.readFully(saslToken); } + CipherOption cipherOption = null; while (!isComplete()) { saslToken = evaluateChallenge(saslToken); if (saslToken != null) { @@ -127,8 +144,40 @@ public class HBaseSaslRpcClient extends AbstractHBaseSaslRpcClient { + " for processing by initSASLContext"); } inStream.readFully(saslToken); + } else if (useNegotiatedCipher && SaslUtil.isNegotiatedQopPrivacy(saslClient)) { + readStatus(inStream); + int len = inStream.readInt(); + if (len == SaslUtil.USE_NEGOTIATED_CIPHER) { + String cipherName = WritableUtils.readString(inStream); + byte[] inKey = new byte[inStream.readInt()]; + inStream.readFully(inKey); + byte[] inIv = new byte[inStream.readInt()]; + inStream.readFully(inIv); + byte[] outKey = new byte[inStream.readInt()]; + inStream.readFully(outKey); + byte[] outIv = new byte[inStream.readInt()]; + inStream.readFully(outIv); + CipherOption wrappedCipherOption = new 
CipherOption(CipherSuite.convert(cipherName), + inKey, inIv, outKey, outIv); + // Unwrap the negotiated cipher option + cipherOption = SaslUtil.unwrap(wrappedCipherOption, saslClient); + } else if (len != 0) { + LOG.warn("Have read unexpected input token of size " + len + + " after client is complete"); + } + if (LOG.isDebugEnabled()) { + if (cipherOption == null) { + LOG.debug("Client not using any cipher suite"); + } else { + LOG.debug("Client using cipher suite " + + cipherOption.getCipherSuite().getName() + " with server"); + } + } } } + if (cipherOption != null) { + saslCodec = new SaslCryptoCodec(conf, cipherOption, false); + } if (LOG.isDebugEnabled()) { LOG.debug("SASL client context established. Negotiated QoP: " + saslClient.getNegotiatedProperty(Sasl.QOP)); @@ -154,7 +203,7 @@ public class HBaseSaslRpcClient extends AbstractHBaseSaslRpcClient { if (!saslClient.isComplete()) { throw new IOException("Sasl authentication exchange hasn't completed yet"); } - return new SaslInputStream(in, saslClient); + return saslCodec != null ? new SaslInputStream(in, saslCodec) : new SaslInputStream(in, saslClient); } /** @@ -167,6 +216,6 @@ public class HBaseSaslRpcClient extends AbstractHBaseSaslRpcClient { if (!saslClient.isComplete()) { throw new IOException("Sasl authentication exchange hasn't completed yet"); } - return new SaslOutputStream(out, saslClient); + return saslCodec != null ? 
new SaslOutputStream(out, saslCodec) : new SaslOutputStream(out, saslClient); } } diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslUtil.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslUtil.java index aaa9d7a..ba03542 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslUtil.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslUtil.java @@ -18,23 +18,45 @@ */ package org.apache.hadoop.hbase.security; +import java.io.IOException; +import java.util.Arrays; +import java.util.List; import java.util.Map; +import java.util.Set; import java.util.TreeMap; import javax.security.sasl.Sasl; import javax.security.sasl.SaslClient; import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; +import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; import org.apache.commons.codec.binary.Base64; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.crypto.CipherOption; +import org.apache.hadoop.crypto.CipherSuite; +import org.apache.hadoop.crypto.CryptoCodec; import org.apache.hadoop.hbase.classification.InterfaceAudience; +import org.apache.hadoop.security.SaslRpcServer; @InterfaceAudience.Private public class SaslUtil { private static final Log LOG = LogFactory.getLog(SaslUtil.class); + + public static final String HBASE_RPC_SECURITY_CRYPTO_CIPHER_SUITES = + "hbase.rpc.security.crypto.cipher.suites"; + public static final String HBASE_RPC_SECURITY_CRYPTO_CIPHER_KEY_BITLENGTH_KEY = + "hbase.rpc.security.crypto.cipher.key.bitlength"; + public static final int HBASE_RPC_SECURITY_CRYPTO_CIPHER_KEY_BITLENGTH_DEFAULT = + 128; + public static final String SASL_DEFAULT_REALM = "default"; public static final int SWITCH_TO_SIMPLE_AUTH = -88; + public static final int USE_NEGOTIATED_CIPHER = -89; 
public enum QualityOfProtection { AUTHENTICATION("auth"), @@ -123,4 +145,159 @@ public class SaslUtil { LOG.error("Error disposing of SASL client", e); } } + + /** + * Check whether requested SASL Qop contains privacy. + * + * @param saslProps properties of SASL negotiation + * @return boolean true if privacy exists + */ + public static boolean requestedQopContainsPrivacy( + Map saslProps) { + Set requestedQop = ImmutableSet.copyOf(Arrays.asList( + saslProps.get(Sasl.QOP).split(","))); + return requestedQop.contains( + SaslRpcServer.QualityOfProtection.PRIVACY.getSaslQop()); + } + + /** + * After successful SASL negotiation, returns whether it's QOP privacy. + * + * @return boolean whether it's QOP privacy + */ + public static boolean isNegotiatedQopPrivacy(SaslServer saslServer) { + String qop = (String) saslServer.getNegotiatedProperty(Sasl.QOP); + return qop != null && SaslUtil.QualityOfProtection.PRIVACY + .getSaslQop().equalsIgnoreCase(qop); + } + + /** + * After successful SASL negotiation, returns whether it's QOP privacy. + * + * @return boolean whether it's QOP privacy + */ + public static boolean isNegotiatedQopPrivacy(SaslClient saslClient) { + String qop = (String) saslClient.getNegotiatedProperty(Sasl.QOP); + return qop != null && SaslUtil.QualityOfProtection.PRIVACY + .getSaslQop().equalsIgnoreCase(qop); + } + + /** + * Negotiate a cipher option which server supports. + * + * @param conf the configuration + * @param options the cipher options which client supports + * @return CipherOption negotiated cipher option + */ + public static CipherOption negotiateCipherOption(Configuration conf, + List options) throws IOException { + // Negotiate cipher suites if configured. Currently, the only supported + // cipher suite is AES/CTR/NoPadding, but the protocol allows multiple + // values for future expansion. 
+ String cipherSuites = conf.get(HBASE_RPC_SECURITY_CRYPTO_CIPHER_SUITES); + if (cipherSuites == null || cipherSuites.isEmpty()) { + return null; + } + if (!cipherSuites.equals(CipherSuite.AES_CTR_NOPADDING.getName())) { + throw new IOException(String.format("Invalid cipher suite, %s=%s", + HBASE_RPC_SECURITY_CRYPTO_CIPHER_SUITES, cipherSuites)); + } + if (options != null) { + for (CipherOption option : options) { + CipherSuite suite = option.getCipherSuite(); + if (suite == CipherSuite.AES_CTR_NOPADDING) { + int keyLen = conf.getInt( + HBASE_RPC_SECURITY_CRYPTO_CIPHER_KEY_BITLENGTH_KEY, + HBASE_RPC_SECURITY_CRYPTO_CIPHER_KEY_BITLENGTH_DEFAULT) / 8; + CryptoCodec codec = CryptoCodec.getInstance(conf, suite); + byte[] inKey = new byte[keyLen]; + byte[] inIv = new byte[suite.getAlgorithmBlockSize()]; + byte[] outKey = new byte[keyLen]; + byte[] outIv = new byte[suite.getAlgorithmBlockSize()]; + assert codec != null; + codec.generateSecureRandom(inKey); + codec.generateSecureRandom(inIv); + codec.generateSecureRandom(outKey); + codec.generateSecureRandom(outIv); + return new CipherOption(suite, inKey, inIv, outKey, outIv); + } + } + } + return null; + } + + /** + * Encrypt the key of the negotiated cipher option. + * + * @param option negotiated cipher option + * @param saslServer SASL server + * @return CipherOption negotiated cipher option which contains the + * encrypted key and iv + * @throws IOException for any error + */ + public static CipherOption wrap(CipherOption option, SaslServer saslServer) + throws IOException { + if (option != null) { + byte[] inKey = option.getInKey(); + if (inKey != null) { + inKey = saslServer.wrap(inKey, 0, inKey.length); + } + byte[] outKey = option.getOutKey(); + if (outKey != null) { + outKey = saslServer.wrap(outKey, 0, outKey.length); + } + return new CipherOption(option.getCipherSuite(), inKey, option.getInIv(), + outKey, option.getOutIv()); + } + + return null; + } + + /** + * Decrypt the key of the negotiated cipher option. 
+ * + * @param option negotiated cipher option + * @param saslClient SASL client + * @return CipherOption negotiated cipher option which contains the + * decrypted key and iv + * @throws IOException for any error + */ + public static CipherOption unwrap(CipherOption option, SaslClient saslClient) + throws IOException { + if (option != null) { + byte[] inKey = option.getInKey(); + if (inKey != null) { + inKey = saslClient.unwrap(inKey, 0, inKey.length); + } + byte[] outKey = option.getOutKey(); + if (outKey != null) { + outKey = saslClient.unwrap(outKey, 0, outKey.length); + } + return new CipherOption(option.getCipherSuite(), inKey, option.getInIv(), + outKey, option.getOutIv()); + } + + return null; + } + + /** + * Read the cipher options from the given string. + * + * @param cipherSuites the ciphers as a string + * @return List of the cipher options + */ + public static List getCipherOptions(String cipherSuites) + throws IOException { + List cipherOptions = null; + if (cipherSuites != null && !cipherSuites.isEmpty()) { + cipherOptions = Lists.newArrayListWithCapacity(1); + for (String cipherSuite : Splitter.on(',').trimResults(). + omitEmptyStrings().split(cipherSuites)) { + CipherOption option = new CipherOption( + CipherSuite.convert(cipherSuite)); + cipherOptions.add(option); + } + } + return cipherOptions; + } } diff --git a/hbase-common/src/main/java/org/apache/hadoop/hbase/crypto/CryptoInputStream.java b/hbase-common/src/main/java/org/apache/hadoop/hbase/crypto/CryptoInputStream.java new file mode 100644 index 0000000..5832881 --- /dev/null +++ b/hbase-common/src/main/java/org/apache/hadoop/hbase/crypto/CryptoInputStream.java @@ -0,0 +1,740 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.crypto; + +import java.io.FileDescriptor; +import java.io.FileInputStream; +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; +import java.security.GeneralSecurityException; +import java.util.EnumSet; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; + +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; +import org.apache.hadoop.crypto.*; +import org.apache.hadoop.fs.ByteBufferReadable; +import org.apache.hadoop.fs.CanSetDropBehind; +import org.apache.hadoop.fs.CanSetReadahead; +import org.apache.hadoop.fs.HasEnhancedByteBufferAccess; +import org.apache.hadoop.fs.HasFileDescriptor; +import org.apache.hadoop.fs.PositionedReadable; +import org.apache.hadoop.fs.ReadOption; +import org.apache.hadoop.fs.Seekable; +import org.apache.hadoop.io.ByteBufferPool; + +import com.google.common.base.Preconditions; + +/** + * CryptoInputStream decrypts data. It is not thread-safe. AES CTR mode is + * required in order to ensure that the plain text and cipher text have a 1:1 + * mapping. The decryption is buffer based. The key points of the decryption + * are (1) calculating the counter and (2) padding through stream position: + *
+ * <p/>
+ * counter = base + pos/(algorithm blocksize);
+ * padding = pos%(algorithm blocksize);
+ * <p/>
+ * The underlying stream offset is maintained as state. + */ +@InterfaceAudience.Private +@InterfaceStability.Evolving +public class CryptoInputStream extends FilterInputStream implements + Seekable, PositionedReadable, ByteBufferReadable, HasFileDescriptor, + CanSetDropBehind, CanSetReadahead, HasEnhancedByteBufferAccess, + ReadableByteChannel { + private final byte[] oneByteBuf = new byte[1]; + private final CryptoCodec codec; + private final Decryptor decryptor; + private final int bufferSize; + + /** + * Input data buffer. The data starts at inBuffer.position() and ends at + * to inBuffer.limit(). + */ + private ByteBuffer inBuffer; + + /** + * The decrypted data buffer. The data starts at outBuffer.position() and + * ends at outBuffer.limit(); + */ + private ByteBuffer outBuffer; + private long streamOffset = 0; // Underlying stream offset. + + /** + * Whether the underlying stream supports + * {@link org.apache.hadoop.fs.ByteBufferReadable} + */ + private Boolean usingByteBufferRead = null; + + /** + * Padding = pos%(algorithm blocksize); Padding is put into {@link #inBuffer} + * before any other data goes in. The purpose of padding is to put the input + * data at proper position. + */ + private byte padding; + private boolean closed; + private final byte[] key; + private final byte[] initIV; + private byte[] iv; + private boolean isByteBufferReadable; + private boolean isReadableByteChannel; + + /** DirectBuffer pool */ + private final Queue bufferPool = + new ConcurrentLinkedQueue(); + /** Decryptor pool */ + private final Queue decryptorPool = + new ConcurrentLinkedQueue(); + + public CryptoInputStream(InputStream in, CryptoCodec codec, + int bufferSize, byte[] key, byte[] iv) throws IOException { + this(in, codec, bufferSize, key, iv, + in != null ? 
CryptoStreamUtils.getInputStreamOffset(in) : 0); + } + + public CryptoInputStream(InputStream in, CryptoCodec codec, + int bufferSize, byte[] key, byte[] iv, long streamOffset) throws IOException { + super(in); + CryptoStreamUtils.checkCodec(codec); + this.bufferSize = CryptoStreamUtils.checkBufferSize(codec, bufferSize); + this.codec = codec; + this.key = key.clone(); + this.initIV = iv.clone(); + this.iv = iv.clone(); + this.streamOffset = streamOffset; + if (in != null) { + isByteBufferReadable = in instanceof ByteBufferReadable; + isReadableByteChannel = in instanceof ReadableByteChannel; + } + inBuffer = ByteBuffer.allocateDirect(this.bufferSize); + outBuffer = ByteBuffer.allocateDirect(this.bufferSize); + decryptor = getDecryptor(); + resetStreamOffset(streamOffset); + } + + public CryptoInputStream(InputStream in, CryptoCodec codec, + byte[] key, byte[] iv) throws IOException { + this(in, codec, CryptoStreamUtils.getBufferSize(codec.getConf()), key, iv); + } + + public InputStream getWrappedStream() { + return in; + } + + public void setWrappedStream(InputStream in) { + this.in = in; + if (in != null) { + isByteBufferReadable = in instanceof ByteBufferReadable; + isReadableByteChannel = in instanceof ReadableByteChannel; + usingByteBufferRead = null; + } + } + + /** + * Decryption is buffer based. + * If there is data in {@link #outBuffer}, then read it out of this buffer. + * If there is no data in {@link #outBuffer}, then read more from the + * underlying stream and do the decryption. + * @param b the buffer into which the decrypted data is read. + * @param off the buffer offset. + * @param len the maximum number of decrypted data bytes to read. + * @return int the total number of decrypted data bytes read into the buffer. 
+ * @throws IOException + */ + @Override + public int read(byte[] b, int off, int len) throws IOException { + checkStream(); + if (b == null) { + throw new NullPointerException(); + } else if (off < 0 || len < 0 || len > b.length - off) { + throw new IndexOutOfBoundsException(); + } else if (len == 0) { + return 0; + } + + final int remaining = outBuffer.remaining(); + if (remaining > 0) { + int n = Math.min(len, remaining); + outBuffer.get(b, off, n); + return n; + } else { + int n = 0; + + /* + * Check whether the underlying stream is {@link ByteBufferReadable}, + * it can avoid bytes copy. + */ + if (usingByteBufferRead == null) { + if (isByteBufferReadable || isReadableByteChannel) { + try { + n = isByteBufferReadable ? + ((ByteBufferReadable) in).read(inBuffer) : + ((ReadableByteChannel) in).read(inBuffer); + usingByteBufferRead = Boolean.TRUE; + } catch (UnsupportedOperationException e) { + usingByteBufferRead = Boolean.FALSE; + } + } else { + usingByteBufferRead = Boolean.FALSE; + } + if (!usingByteBufferRead) { + n = readFromUnderlyingStream(inBuffer); + } + } else { + if (usingByteBufferRead) { + n = isByteBufferReadable ? ((ByteBufferReadable) in).read(inBuffer) : + ((ReadableByteChannel) in).read(inBuffer); + } else { + n = readFromUnderlyingStream(inBuffer); + } + } + if (n <= 0) { + return n; + } + + streamOffset += n; // Read n bytes + decrypt(decryptor, inBuffer, outBuffer, padding); + padding = afterDecryption(decryptor, inBuffer, streamOffset, iv); + n = Math.min(len, outBuffer.remaining()); + outBuffer.get(b, off, n); + return n; + } + } + + public void readFully(byte[] b, int off, int len) throws IOException { + int read = 0; + while (read < len) { + int n = read(b, off + read, len - read); + if (n < 0) { + throw new IOException("End of file reached before reading fully."); + } + read += n; + } + } + + /** Read data from underlying stream. 
*/ + private int readFromUnderlyingStream(ByteBuffer inBuffer) throws IOException { + final int toRead = inBuffer.remaining(); + final byte[] tmp = getTmpBuf(); + final int n = in.read(tmp, 0, toRead); + if (n > 0) { + inBuffer.put(tmp, 0, n); + } + return n; + } + + private byte[] tmpBuf; + private byte[] getTmpBuf() { + if (tmpBuf == null) { + tmpBuf = new byte[bufferSize]; + } + return tmpBuf; + } + + /** + * Do the decryption using inBuffer as input and outBuffer as output. + * Upon return, inBuffer is cleared; the decrypted data starts at + * outBuffer.position() and ends at outBuffer.limit(); + */ + private void decrypt(Decryptor decryptor, ByteBuffer inBuffer, + ByteBuffer outBuffer, byte padding) throws IOException { + Preconditions.checkState(inBuffer.position() >= padding); + if(inBuffer.position() == padding) { + // There is no real data in inBuffer. + return; + } + inBuffer.flip(); + outBuffer.clear(); + decryptor.decrypt(inBuffer, outBuffer); + inBuffer.clear(); + outBuffer.flip(); + if (padding > 0) { + /* + * The plain text and cipher text have a 1:1 mapping, they start at the + * same position. + */ + outBuffer.position(padding); + } + } + + /** + * This method is executed immediately after decryption. Check whether + * decryptor should be updated and recalculate padding if needed. + */ + private byte afterDecryption(Decryptor decryptor, ByteBuffer inBuffer, + long position, byte[] iv) throws IOException { + byte padding = 0; + if (decryptor.isContextReset()) { + /* + * This code is generally not executed since the decryptor usually + * maintains decryption context (e.g. the counter) internally. However, + * some implementations can't maintain context so a re-init is necessary + * after each decryption call. 
+ */ + updateDecryptor(decryptor, position, iv); + padding = getPadding(position); + inBuffer.position(padding); + } + return padding; + } + + private long getCounter(long position) { + return position / codec.getCipherSuite().getAlgorithmBlockSize(); + } + + private byte getPadding(long position) { + return (byte)(position % codec.getCipherSuite().getAlgorithmBlockSize()); + } + + /** Calculate the counter and iv, update the decryptor. */ + private void updateDecryptor(Decryptor decryptor, long position, byte[] iv) + throws IOException { + final long counter = getCounter(position); + codec.calculateIV(initIV, counter, iv); + decryptor.init(key, iv); + } + + /** + * Reset the underlying stream offset; clear {@link #inBuffer} and + * {@link #outBuffer}. This Typically happens during {@link #seek(long)} + * or {@link #skip(long)}. + */ + private void resetStreamOffset(long offset) throws IOException { + streamOffset = offset; + inBuffer.clear(); + outBuffer.clear(); + outBuffer.limit(0); + updateDecryptor(decryptor, offset, iv); + padding = getPadding(offset); + inBuffer.position(padding); // Set proper position for input data. + } + + @Override + public void close() throws IOException { + if (closed) { + return; + } + + super.close(); + freeBuffers(); + closed = true; + } + + /** Positioned read. It is thread-safe */ + @Override + public int read(long position, byte[] buffer, int offset, int length) + throws IOException { + checkStream(); + try { + final int n = ((PositionedReadable) in).read(position, buffer, offset, + length); + if (n > 0) { + // This operation does not change the current offset of the file + decrypt(position, buffer, offset, n); + } + + return n; + } catch (ClassCastException e) { + throw new UnsupportedOperationException("This stream does not support " + + "positioned read."); + } + } + + /** + * Decrypt length bytes in buffer starting at offset. Output is also put + * into buffer starting at offset. It is thread-safe. 
+ */ + private void decrypt(long position, byte[] buffer, int offset, int length) + throws IOException { + ByteBuffer inBuffer = getBuffer(); + ByteBuffer outBuffer = getBuffer(); + Decryptor decryptor = null; + try { + decryptor = getDecryptor(); + byte[] iv = initIV.clone(); + updateDecryptor(decryptor, position, iv); + byte padding = getPadding(position); + inBuffer.position(padding); // Set proper position for input data. + + int n = 0; + while (n < length) { + int toDecrypt = Math.min(length - n, inBuffer.remaining()); + inBuffer.put(buffer, offset + n, toDecrypt); + // Do decryption + decrypt(decryptor, inBuffer, outBuffer, padding); + + outBuffer.get(buffer, offset + n, toDecrypt); + n += toDecrypt; + padding = afterDecryption(decryptor, inBuffer, position + n, iv); + } + } finally { + returnBuffer(inBuffer); + returnBuffer(outBuffer); + returnDecryptor(decryptor); + } + } + + /** Positioned read fully. It is thread-safe */ + @Override + public void readFully(long position, byte[] buffer, int offset, int length) + throws IOException { + checkStream(); + try { + ((PositionedReadable) in).readFully(position, buffer, offset, length); + if (length > 0) { + // This operation does not change the current offset of the file + decrypt(position, buffer, offset, length); + } + } catch (ClassCastException e) { + throw new UnsupportedOperationException("This stream does not support " + + "positioned readFully."); + } + } + + @Override + public void readFully(long position, byte[] buffer) throws IOException { + readFully(position, buffer, 0, buffer.length); + } + + /** Seek to a position. */ + @Override + public void seek(long pos) throws IOException { + Preconditions.checkArgument(pos >= 0, "Cannot seek to negative offset."); + checkStream(); + try { + /* + * If data of target pos in the underlying stream has already been read + * and decrypted in outBuffer, we just need to re-position outBuffer. 
+ */ + if (pos <= streamOffset && pos >= (streamOffset - outBuffer.remaining())) { + int forward = (int) (pos - (streamOffset - outBuffer.remaining())); + if (forward > 0) { + outBuffer.position(outBuffer.position() + forward); + } + } else { + ((Seekable) in).seek(pos); + resetStreamOffset(pos); + } + } catch (ClassCastException e) { + throw new UnsupportedOperationException("This stream does not support " + + "seek."); + } + } + + /** Skip n bytes */ + @Override + public long skip(long n) throws IOException { + Preconditions.checkArgument(n >= 0, "Negative skip length."); + checkStream(); + + if (n == 0) { + return 0; + } else if (n <= outBuffer.remaining()) { + int pos = outBuffer.position() + (int) n; + outBuffer.position(pos); + return n; + } else { + /* + * Subtract outBuffer.remaining() to see how many bytes we need to + * skip in the underlying stream. Add outBuffer.remaining() to the + * actual number of skipped bytes in the underlying stream to get the + * number of skipped bytes from the user's point of view. + */ + n -= outBuffer.remaining(); + long skipped = in.skip(n); + if (skipped < 0) { + skipped = 0; + } + long pos = streamOffset + skipped; + skipped += outBuffer.remaining(); + resetStreamOffset(pos); + return skipped; + } + } + + /** Get underlying stream position. */ + @Override + public long getPos() throws IOException { + checkStream(); + // Equals: ((Seekable) in).getPos() - outBuffer.remaining() + return streamOffset - outBuffer.remaining(); + } + + /** ByteBuffer read. */ + @Override + public int read(ByteBuffer buf) throws IOException { + checkStream(); + if (isByteBufferReadable || isReadableByteChannel) { + final int unread = outBuffer.remaining(); + if (unread > 0) { // Have unread decrypted data in buffer. 
+ int toRead = buf.remaining(); + if (toRead <= unread) { + final int limit = outBuffer.limit(); + outBuffer.limit(outBuffer.position() + toRead); + buf.put(outBuffer); + outBuffer.limit(limit); + return toRead; + } else { + buf.put(outBuffer); + } + } + + final int pos = buf.position(); + final int n = isByteBufferReadable ? ((ByteBufferReadable) in).read(buf) : + ((ReadableByteChannel) in).read(buf); + if (n > 0) { + streamOffset += n; // Read n bytes + decrypt(buf, n, pos); + } + + if (n >= 0) { + return unread + n; + } else { + if (unread == 0) { + return -1; + } else { + return unread; + } + } + } else { + int n = 0; + if (buf.hasArray()) { + n = read(buf.array(), buf.position(), buf.remaining()); + if (n > 0) { + buf.position(buf.position() + n); + } + } else { + byte[] tmp = new byte[buf.remaining()]; + n = read(tmp); + if (n > 0) { + buf.put(tmp, 0, n); + } + } + return n; + } + } + + /** + * Decrypt all data in buf: total n bytes from given start position. + * Output is also buf and same start position. + * buf.position() and buf.limit() should be unchanged after decryption. 
+ */ + private void decrypt(ByteBuffer buf, int n, int start) + throws IOException { + final int pos = buf.position(); + final int limit = buf.limit(); + int len = 0; + while (len < n) { + buf.position(start + len); + buf.limit(start + len + Math.min(n - len, inBuffer.remaining())); + inBuffer.put(buf); + // Do decryption + try { + decrypt(decryptor, inBuffer, outBuffer, padding); + buf.position(start + len); + buf.limit(limit); + len += outBuffer.remaining(); + buf.put(outBuffer); + } finally { + padding = afterDecryption(decryptor, inBuffer, streamOffset - (n - len), iv); + } + } + buf.position(pos); + } + + @Override + public int available() throws IOException { + checkStream(); + + return in.available() + outBuffer.remaining(); + } + + @Override + public boolean markSupported() { + return false; + } + + @Override + public void mark(int readLimit) { + } + + @Override + public void reset() throws IOException { + throw new IOException("Mark/reset not supported"); + } + + @Override + public boolean seekToNewSource(long targetPos) throws IOException { + Preconditions.checkArgument(targetPos >= 0, + "Cannot seek to negative offset."); + checkStream(); + try { + boolean result = ((Seekable) in).seekToNewSource(targetPos); + resetStreamOffset(targetPos); + return result; + } catch (ClassCastException e) { + throw new UnsupportedOperationException("This stream does not support " + + "seekToNewSource."); + } + } + + @Override + public ByteBuffer read(ByteBufferPool bufferPool, int maxLength, + EnumSet opts) throws IOException, + UnsupportedOperationException { + checkStream(); + try { + if (outBuffer.remaining() > 0) { + // Have some decrypted data unread, need to reset. + ((Seekable) in).seek(getPos()); + resetStreamOffset(getPos()); + } + final ByteBuffer buffer = ((HasEnhancedByteBufferAccess) in). 
+ read(bufferPool, maxLength, opts); + if (buffer != null) { + final int n = buffer.remaining(); + if (n > 0) { + streamOffset += buffer.remaining(); // Read n bytes + final int pos = buffer.position(); + decrypt(buffer, n, pos); + } + } + return buffer; + } catch (ClassCastException e) { + throw new UnsupportedOperationException("This stream does not support " + + "enhanced byte buffer access."); + } + } + + @Override + public void releaseBuffer(ByteBuffer buffer) { + try { + ((HasEnhancedByteBufferAccess) in).releaseBuffer(buffer); + } catch (ClassCastException e) { + throw new UnsupportedOperationException("This stream does not support " + + "release buffer."); + } + } + + @Override + public void setReadahead(Long readahead) throws IOException, + UnsupportedOperationException { + try { + ((CanSetReadahead) in).setReadahead(readahead); + } catch (ClassCastException e) { + throw new UnsupportedOperationException("This stream does not support " + + "setting the readahead caching strategy."); + } + } + + @Override + public void setDropBehind(Boolean dropCache) throws IOException, + UnsupportedOperationException { + try { + ((CanSetDropBehind) in).setDropBehind(dropCache); + } catch (ClassCastException e) { + throw new UnsupportedOperationException("This stream does not " + + "support setting the drop-behind caching setting."); + } + } + + @Override + public FileDescriptor getFileDescriptor() throws IOException { + if (in instanceof HasFileDescriptor) { + return ((HasFileDescriptor) in).getFileDescriptor(); + } else if (in instanceof FileInputStream) { + return ((FileInputStream) in).getFD(); + } else { + return null; + } + } + + @Override + public int read() throws IOException { + return (read(oneByteBuf, 0, 1) == -1) ? 
-1 : (oneByteBuf[0] & 0xff); + } + + private void checkStream() throws IOException { + if (closed) { + throw new IOException("Stream closed"); + } + } + + /** Get direct buffer from pool */ + private ByteBuffer getBuffer() { + ByteBuffer buffer = bufferPool.poll(); + if (buffer == null) { + buffer = ByteBuffer.allocateDirect(bufferSize); + } + + return buffer; + } + + /** Return direct buffer to pool */ + private void returnBuffer(ByteBuffer buf) { + if (buf != null) { + buf.clear(); + bufferPool.add(buf); + } + } + + /** Forcibly free the direct buffers. */ + private void freeBuffers() { + CryptoStreamUtils.freeDB(inBuffer); + CryptoStreamUtils.freeDB(outBuffer); + cleanBufferPool(); + } + + /** Clean direct buffer pool */ + private void cleanBufferPool() { + ByteBuffer buf; + while ((buf = bufferPool.poll()) != null) { + CryptoStreamUtils.freeDB(buf); + } + } + + /** Get decryptor from pool */ + private Decryptor getDecryptor() throws IOException { + Decryptor decryptor = decryptorPool.poll(); + if (decryptor == null) { + try { + decryptor = codec.createDecryptor(); + } catch (GeneralSecurityException e) { + throw new IOException(e); + } + } + + return decryptor; + } + + /** Return decryptor to pool */ + private void returnDecryptor(Decryptor decryptor) { + if (decryptor != null) { + decryptorPool.add(decryptor); + } + } + + @Override + public boolean isOpen() { + return !closed; + } +} diff --git a/hbase-common/src/main/java/org/apache/hadoop/hbase/crypto/CryptoOutputStream.java b/hbase-common/src/main/java/org/apache/hadoop/hbase/crypto/CryptoOutputStream.java new file mode 100644 index 0000000..b4d3cd1 --- /dev/null +++ b/hbase-common/src/main/java/org/apache/hadoop/hbase/crypto/CryptoOutputStream.java @@ -0,0 +1,296 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.crypto; + +import java.io.FilterOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.security.GeneralSecurityException; + +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; +import org.apache.hadoop.crypto.*; +import org.apache.hadoop.fs.CanSetDropBehind; +import org.apache.hadoop.fs.Syncable; + +import com.google.common.base.Preconditions; + +/** + * CryptoOutputStream encrypts data. It is not thread-safe. AES CTR mode is + * required in order to ensure that the plain text and cipher text have a 1:1 + * mapping. The encryption is buffer based. The key points of the encryption are + * (1) calculating counter and (2) padding through stream position. + *

+ * counter = base + pos/(algorithm blocksize); + * padding = pos%(algorithm blocksize); + *

+ * The underlying stream offset is maintained as state. + * + * Note that while some of this class' methods are synchronized, this is just to + * match the threadsafety behavior of DFSOutputStream. See HADOOP-11710. + */ +@InterfaceAudience.Private +@InterfaceStability.Evolving +public class CryptoOutputStream extends FilterOutputStream implements + Syncable, CanSetDropBehind { + private final byte[] oneByteBuf = new byte[1]; + private final CryptoCodec codec; + private final Encryptor encryptor; + private final int bufferSize; + + /** + * Input data buffer. The data starts at inBuffer.position() and ends at + * inBuffer.limit(). + */ + private ByteBuffer inBuffer; + + /** + * Encrypted data buffer. The data starts at outBuffer.position() and ends at + * outBuffer.limit(); + */ + private ByteBuffer outBuffer; + private long streamOffset = 0; // Underlying stream offset. + + /** + * Padding = pos%(algorithm blocksize); Padding is put into {@link #inBuffer} + * before any other data goes in. The purpose of padding is to put input data + * at proper position. 
+ */ + private byte padding; + private boolean closed; + private final byte[] key; + private final byte[] initIV; + private byte[] iv; + + public CryptoOutputStream(OutputStream out, CryptoCodec codec, + int bufferSize, byte[] key, byte[] iv) throws IOException { + this(out, codec, bufferSize, key, iv, 0); + } + + public CryptoOutputStream(OutputStream out, CryptoCodec codec, + int bufferSize, byte[] key, byte[] iv, long streamOffset) + throws IOException { + super(out); + CryptoStreamUtils.checkCodec(codec); + this.bufferSize = CryptoStreamUtils.checkBufferSize(codec, bufferSize); + this.codec = codec; + this.key = key.clone(); + this.initIV = iv.clone(); + this.iv = iv.clone(); + inBuffer = ByteBuffer.allocateDirect(this.bufferSize); + outBuffer = ByteBuffer.allocateDirect(this.bufferSize); + this.streamOffset = streamOffset; + try { + encryptor = codec.createEncryptor(); + } catch (GeneralSecurityException e) { + throw new IOException(e); + } + updateEncryptor(); + } + + public CryptoOutputStream(OutputStream out, CryptoCodec codec, + byte[] key, byte[] iv) throws IOException { + this(out, codec, key, iv, 0); + } + + public CryptoOutputStream(OutputStream out, CryptoCodec codec, + byte[] key, byte[] iv, long streamOffset) throws IOException { + this(out, codec, CryptoStreamUtils.getBufferSize(codec.getConf()), + key, iv, streamOffset); + } + + public OutputStream getWrappedStream() { + return out; + } + + public void setWrappedStream(OutputStream out) { + this.out = out; + } + + /** + * Encryption is buffer based. + * If there is enough room in {@link #inBuffer}, then write to this buffer. + * If {@link #inBuffer} is full, then do encryption and write data to the + * underlying stream. + * @param b the data. + * @param off the start offset in the data. + * @param len the number of bytes to write. 
+ * @throws IOException + */ + @Override + public synchronized void write(byte[] b, int off, int len) throws IOException { + checkStream(); + if (b == null) { + throw new NullPointerException(); + } else if (off < 0 || len < 0 || off > b.length || + len > b.length - off) { + throw new IndexOutOfBoundsException(); + } + while (len > 0) { + final int remaining = inBuffer.remaining(); + if (len < remaining) { + inBuffer.put(b, off, len); + len = 0; + } else { + inBuffer.put(b, off, remaining); + off += remaining; + len -= remaining; + encrypt(); + } + } + } + + /** + * Do the encryption, input is {@link #inBuffer} and output is + * {@link #outBuffer}. + */ + private void encrypt() throws IOException { + Preconditions.checkState(inBuffer.position() >= padding); + if (inBuffer.position() == padding) { + // There is no real data in the inBuffer. + return; + } + inBuffer.flip(); + outBuffer.clear(); + encryptor.encrypt(inBuffer, outBuffer); + inBuffer.clear(); + outBuffer.flip(); + if (padding > 0) { + /* + * The plain text and cipher text have a 1:1 mapping, they start at the + * same position. + */ + outBuffer.position(padding); + padding = 0; + } + final int len = outBuffer.remaining(); + + /* + * If underlying stream supports {@link ByteBuffer} write in future, needs + * refine here. + */ + final byte[] tmp = getTmpBuf(); + outBuffer.get(tmp, 0, len); + out.write(tmp, 0, len); + + streamOffset += len; + if (encryptor.isContextReset()) { + /* + * This code is generally not executed since the encryptor usually + * maintains encryption context (e.g. the counter) internally. However, + * some implementations can't maintain context so a re-init is necessary + * after each encryption call. + */ + updateEncryptor(); + } + } + + /** Update the {@link #encryptor}: calculate counter and {@link #padding}. 
*/ + private void updateEncryptor() throws IOException { + final long counter = + streamOffset / codec.getCipherSuite().getAlgorithmBlockSize(); + padding = + (byte)(streamOffset % codec.getCipherSuite().getAlgorithmBlockSize()); + inBuffer.position(padding); // Set proper position for input data. + codec.calculateIV(initIV, counter, iv); + encryptor.init(key, iv); + } + + private byte[] tmpBuf; + private byte[] getTmpBuf() { + if (tmpBuf == null) { + tmpBuf = new byte[bufferSize]; + } + return tmpBuf; + } + + @Override + public synchronized void close() throws IOException { + if (closed) { + return; + } + try { + super.close(); + freeBuffers(); + } finally { + closed = true; + } + } + + /** + * To flush, we need to encrypt the data in the buffer and write to the + * underlying stream, then do the flush. + */ + @Override + public synchronized void flush() throws IOException { + checkStream(); + encrypt(); + super.flush(); + } + + @Override + public void write(int b) throws IOException { + oneByteBuf[0] = (byte)(b & 0xff); + write(oneByteBuf, 0, oneByteBuf.length); + } + + private void checkStream() throws IOException { + if (closed) { + throw new IOException("Stream closed"); + } + } + + @Override + public void setDropBehind(Boolean dropCache) throws IOException, + UnsupportedOperationException { + try { + ((CanSetDropBehind) out).setDropBehind(dropCache); + } catch (ClassCastException e) { + throw new UnsupportedOperationException("This stream does not " + + "support setting the drop-behind caching."); + } + } + + + public void sync() throws IOException { + hflush(); + } + + @Override + public void hflush() throws IOException { + flush(); + if (out instanceof Syncable) { + ((Syncable)out).hflush(); + } + } + + @Override + public void hsync() throws IOException { + flush(); + if (out instanceof Syncable) { + ((Syncable)out).hsync(); + } + } + + /** Forcibly free the direct buffers. 
*/ + private void freeBuffers() { + CryptoStreamUtils.freeDB(inBuffer); + CryptoStreamUtils.freeDB(outBuffer); + } +} diff --git a/hbase-common/src/main/java/org/apache/hadoop/hbase/security/SaslCryptoCodec.java b/hbase-common/src/main/java/org/apache/hadoop/hbase/security/SaslCryptoCodec.java new file mode 100644 index 0000000..98ff115 --- /dev/null +++ b/hbase-common/src/main/java/org/apache/hadoop/hbase/security/SaslCryptoCodec.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hbase.security; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.crypto.CipherOption; +import org.apache.hadoop.crypto.CryptoCodec; +import org.apache.hadoop.hbase.crypto.CryptoInputStream; +import org.apache.hadoop.hbase.crypto.CryptoOutputStream; + +import javax.crypto.Mac; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import javax.security.sasl.SaslException; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.util.Arrays; + +/** + * Provide the functionality to allow for quality-of-protection (QOP) with + * integrity checking and privacy. It relies on CryptoInputStream and + * CryptoOutputStream to do decryption and encryption. + */ +public class SaslCryptoCodec { + private static final int MAC_LENGTH = 10; + private static final int SEQ_NUM_LENGTH = 4; + + private CryptoInputStream cIn; + private CryptoOutputStream cOut; + + private final Integrity integrity; + + public SaslCryptoCodec(Configuration conf, CipherOption cipherOption, + boolean isServer) throws IOException { + CryptoCodec codec = CryptoCodec.getInstance(conf, + cipherOption.getCipherSuite()); + byte[] inKey = cipherOption.getInKey(); + byte[] inIv = cipherOption.getInIv(); + byte[] outKey = cipherOption.getOutKey(); + byte[] outIv = cipherOption.getOutIv(); + cIn = new CryptoInputStream(null, codec, + isServer ? inKey : outKey, isServer ? inIv : outIv); + cOut = new CryptoOutputStream(new ByteArrayOutputStream(), codec, + isServer ? outKey : inKey, isServer ? outIv : inIv); + integrity = new Integrity(isServer ? outKey : inKey, + isServer ? 
inKey : outKey); + } + + public byte[] wrap(byte[] outgoing, int offset, int len) + throws SaslException { + // mac + byte[] mac = integrity.getHMAC(outgoing, offset, len); + integrity.incMySeqNum(); + + // encrypt + try { + cOut.write(outgoing, offset, len); + cOut.write(mac, 0, MAC_LENGTH); + cOut.flush(); + } catch (IOException ioe) { + throw new SaslException("Encrypt failed", ioe); + } + byte[] encrypted = ((ByteArrayOutputStream) cOut.getWrappedStream()) + .toByteArray(); + ((ByteArrayOutputStream) cOut.getWrappedStream()).reset(); + + // append seqNum used for mac + byte[] wrapped = new byte[encrypted.length + SEQ_NUM_LENGTH]; + System.arraycopy(encrypted, 0, wrapped, 0, encrypted.length); + System.arraycopy(integrity.getSeqNum(), 0, wrapped, + encrypted.length, SEQ_NUM_LENGTH); + + return wrapped; + } + + public byte[] unwrap(byte[] incoming, int offset, int len) + throws SaslException { + // get seqNum + byte[] peerSeqNum = new byte[SEQ_NUM_LENGTH]; + System.arraycopy(incoming, offset + len - SEQ_NUM_LENGTH, peerSeqNum, 0, + SEQ_NUM_LENGTH); + + // get msg and mac + byte[] msg = new byte[len - SEQ_NUM_LENGTH - MAC_LENGTH]; + byte[] mac = new byte[MAC_LENGTH]; + cIn.setWrappedStream(new ByteArrayInputStream(incoming, offset, + len - SEQ_NUM_LENGTH)); + try { + cIn.readFully(msg, 0, msg.length); + cIn.readFully(mac, 0, mac.length); + } catch (IOException ioe) { + throw new SaslException("Decrypt failed" , ioe); + } + + // check mac integrity and msg sequence + if (!integrity.comparePeerHMAC(mac, peerSeqNum, msg, 0, msg.length)) { + throw new SaslException("Unmatched MAC"); + } + if (!integrity.comparePeerSeqNum(peerSeqNum)) { + throw new SaslException("Out of order sequencing of messages. Got: " + + integrity.byteToInt(peerSeqNum) + " Expected: " + + integrity.peerSeqNum); + } + integrity.incPeerSeqNum(); + + return msg; + } + + /** + * Helper class for providing integrity protection. 
+ */ + private static class Integrity { + + private int mySeqNum = 0; + private int peerSeqNum = 0; + private byte[] mySeqNumArray = new byte[SEQ_NUM_LENGTH]; + + private byte[] myKey; + private byte[] peerKey; + + Integrity(byte[] myKey, byte[] peerKey) throws IOException { + this.myKey = myKey; + this.peerKey = peerKey; + } + + byte[] getHMAC(byte[] msg, int start, int len) throws SaslException { + intToByte(mySeqNum); + return calculateHMAC(myKey, mySeqNumArray, msg, start, len); + } + + boolean comparePeerHMAC(byte[] expectedHMAC, byte[] seqNum, byte[] msg, + int start, int len) throws SaslException { + byte[] mac = calculateHMAC(peerKey, seqNum, msg, start, len); + return Arrays.equals(mac, expectedHMAC); + } + + boolean comparePeerSeqNum(byte[] seqNum) { + return this.peerSeqNum == byteToInt(seqNum); + } + + byte[] getSeqNum() { + return mySeqNumArray; + } + + void incMySeqNum() { + mySeqNum++; + } + + void incPeerSeqNum() { + peerSeqNum++; + } + + private byte[] calculateHMAC(byte[] key, byte[] seqNum, byte[] msg, + int start, int len) throws SaslException { + byte[] seqAndMsg = new byte[SEQ_NUM_LENGTH + len]; + System.arraycopy(seqNum, 0, seqAndMsg, 0, SEQ_NUM_LENGTH); + System.arraycopy(msg, start, seqAndMsg, SEQ_NUM_LENGTH, len); + + try { + SecretKey keyKi = new SecretKeySpec(key, "HmacMD5"); + Mac m = Mac.getInstance("HmacMD5"); + m.init(keyKi); + m.update(seqAndMsg); + byte[] hMacMd5 = m.doFinal(); + + /* First 10 bytes of HMAC_MD5 digest */ + byte[] macBuffer = new byte[MAC_LENGTH]; + System.arraycopy(hMacMd5, 0, macBuffer, 0, MAC_LENGTH); + + return macBuffer; + } catch (InvalidKeyException e) { + throw new SaslException("Invalid bytes used for key of HMAC-MD5 hash.", + e); + } catch (NoSuchAlgorithmException e) { + throw new SaslException("Error creating instance of MD5 MAC algorithm", + e); + } + } + + private void intToByte(int num) { + for(int i = 3; i >= 0; i--) { + mySeqNumArray[i] = (byte)(num & 0xff); + num >>>= 8; + } + } + + private int 
byteToInt(byte[] seqNum) { + int answer = 0; + for (int i = 0; i < 4; i++) { + answer <<= 8; + answer |= ((int)seqNum[i] & 0xff); + } + return answer; + } + } +} diff --git a/hbase-common/src/main/java/org/apache/hadoop/hbase/security/SaslInputStream.java b/hbase-common/src/main/java/org/apache/hadoop/hbase/security/SaslInputStream.java new file mode 100644 index 0000000..eb48ccf --- /dev/null +++ b/hbase-common/src/main/java/org/apache/hadoop/hbase/security/SaslInputStream.java @@ -0,0 +1,394 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.hbase.security; + +import java.io.DataInputStream; +import java.io.EOFException; +import java.io.InputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; + +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; + +/** + * A SaslInputStream is composed of an InputStream and a SaslServer (or + * SaslClient) so that read() methods return data that are read in from the + * underlying InputStream but have been additionally processed by the SaslServer + * (or SaslClient) object. The SaslServer (or SaslClient) object must be fully + * initialized before being used by a SaslInputStream. + */ +@InterfaceAudience.LimitedPrivate({"HDFS", "MapReduce"}) +@InterfaceStability.Evolving +public class SaslInputStream extends InputStream implements ReadableByteChannel { + public static final Log LOG = LogFactory.getLog(SaslInputStream.class); + + private final DataInputStream inStream; + /** Should we wrap the communication channel? 
*/ + private final boolean useWrap; + + /* + * data read from the underlying input stream before being processed by SASL + */ + private byte[] saslToken; + private final SaslClient saslClient; + private final SaslServer saslServer; + private final SaslCryptoCodec saslCodec; + private byte[] lengthBuf = new byte[4]; + /* + * buffer holding data that have been processed by SASL, but have not been + * read out + */ + private byte[] obuffer; + // position of the next "new" byte + private int ostart = 0; + // position of the last "new" byte + private int ofinish = 0; + // whether or not this stream is open + private boolean isOpen = true; + + private static int unsignedBytesToInt(byte[] buf) { + if (buf.length != 4) { + throw new IllegalArgumentException( + "Cannot handle byte array other than 4 bytes"); + } + int result = 0; + for (int i = 0; i < 4; i++) { + result <<= 8; + result |= ((int) buf[i] & 0xff); + } + return result; + } + + /** + * Read more data and get them processed
+ * Entry condition: ostart = ofinish <br>
+ * Exit condition: ostart <= ofinish <br>
+ * + * return (ofinish-ostart) (we have this many bytes for you), 0 (no data now, + * but could have more later), or -1 (absolutely no more data) + */ + private int readMoreData() throws IOException { + try { + inStream.readFully(lengthBuf); + int length = unsignedBytesToInt(lengthBuf); + if (LOG.isDebugEnabled()) + LOG.debug("Actual length is " + length); + saslToken = new byte[length]; + inStream.readFully(saslToken); + } catch (EOFException e) { + return -1; + } + try { + if (saslServer != null) { // using saslServer + obuffer = saslServer.unwrap(saslToken, 0, saslToken.length); + } else if (saslClient != null) { // using saslClient + obuffer = saslClient.unwrap(saslToken, 0, saslToken.length); + } else { + obuffer = saslCodec.unwrap(saslToken, 0, saslToken.length); + } + } catch (SaslException se) { + try { + disposeSasl(); + } catch (SaslException ignored) { + } + throw se; + } + ostart = 0; + if (obuffer == null) + ofinish = 0; + else + ofinish = obuffer.length; + return ofinish; + } + + /** + * Disposes of any system resources or security-sensitive information Sasl + * might be using. + * + * @exception SaslException + * if a SASL error occurs. + */ + private void disposeSasl() throws SaslException { + if (saslClient != null) { + saslClient.dispose(); + } + if (saslServer != null) { + saslServer.dispose(); + } + } + + /** + * Constructs a SASLInputStream from an InputStream and a SaslServer
+ * Note: if the specified InputStream or SaslServer is null, a + * NullPointerException may be thrown later when they are used. + * + * @param inStream + * the InputStream to be processed + * @param saslServer + * an initialized SaslServer object + */ + public SaslInputStream(InputStream inStream, SaslServer saslServer) { + this.inStream = new DataInputStream(inStream); + this.saslServer = saslServer; + this.saslClient = null; + this.saslCodec = null; + String qop = (String) saslServer.getNegotiatedProperty(Sasl.QOP); + this.useWrap = qop != null && !"auth".equalsIgnoreCase(qop); + } + + /** + * Constructs a SASLInputStream from an InputStream and a SaslClient
+ * Note: if the specified InputStream or SaslClient is null, a + * NullPointerException may be thrown later when they are used. + * + * @param inStream + * the InputStream to be processed + * @param saslClient + * an initialized SaslClient object + */ + public SaslInputStream(InputStream inStream, SaslClient saslClient) { + this.inStream = new DataInputStream(inStream); + this.saslServer = null; + this.saslClient = saslClient; + this.saslCodec = null; + String qop = (String) saslClient.getNegotiatedProperty(Sasl.QOP); + this.useWrap = qop != null && !"auth".equalsIgnoreCase(qop); + } + + public SaslInputStream(InputStream inStream, SaslCryptoCodec saslCodec) { + this.inStream = new DataInputStream(inStream); + this.saslServer = null; + this.saslClient = null; + this.saslCodec = saslCodec; + this.useWrap = true; + } + + /** + * Reads the next byte of data from this input stream. The value byte is + * returned as an int in the range 0 to + * 255. If no byte is available because the end of the stream has + * been reached, the value -1 is returned. This method blocks + * until input data is available, the end of the stream is detected, or an + * exception is thrown. + *

+ * + * @return the next byte of data, or -1 if the end of the stream + * is reached. + * @exception IOException + * if an I/O error occurs. + */ + @Override + public int read() throws IOException { + if (!useWrap) { + return inStream.read(); + } + if (ostart >= ofinish) { + // we loop for new data as we are blocking + int i = 0; + while (i == 0) + i = readMoreData(); + if (i == -1) + return -1; + } + return ((int) obuffer[ostart++] & 0xff); + } + + /** + * Reads up to b.length bytes of data from this input stream into + * an array of bytes. + *

+ * The read method of InputStream calls the + * read method of three arguments with the arguments + * b, 0, and b.length. + * + * @param b + * the buffer into which the data is read. + * @return the total number of bytes read into the buffer, or -1 + * is there is no more data because the end of the stream has been + * reached. + * @exception IOException + * if an I/O error occurs. + */ + @Override + public int read(byte[] b) throws IOException { + return read(b, 0, b.length); + } + + /** + * Reads up to len bytes of data from this input stream into an + * array of bytes. This method blocks until some input is available. If the + * first argument is null, up to len bytes are read + * and discarded. + * + * @param b + * the buffer into which the data is read. + * @param off + * the start offset of the data. + * @param len + * the maximum number of bytes read. + * @return the total number of bytes read into the buffer, or -1 + * if there is no more data because the end of the stream has been + * reached. + * @exception IOException + * if an I/O error occurs. + */ + @Override + public int read(byte[] b, int off, int len) throws IOException { + if (!useWrap) { + return inStream.read(b, off, len); + } + if (ostart >= ofinish) { + // we loop for new data as we are blocking + int i = 0; + while (i == 0) + i = readMoreData(); + if (i == -1) + return -1; + } + if (len <= 0) { + return 0; + } + int available = ofinish - ostart; + if (len < available) + available = len; + if (b != null) { + System.arraycopy(obuffer, ostart, b, off, available); + } + ostart = ostart + available; + return available; + } + + /** + * Skips n bytes of input from the bytes that can be read from + * this input stream without blocking. + * + *

+ * Fewer bytes than requested might be skipped. The actual number of bytes + * skipped is equal to n or the result of a call to + * {@link #available() available}, whichever is smaller. If + * n is less than zero, no bytes are skipped. + * + *

+ * The actual number of bytes skipped is returned. + * + * @param n + * the number of bytes to be skipped. + * @return the actual number of bytes skipped. + * @exception IOException + * if an I/O error occurs. + */ + @Override + public long skip(long n) throws IOException { + if (!useWrap) { + return inStream.skip(n); + } + int available = ofinish - ostart; + if (n > available) { + n = available; + } + if (n < 0) { + return 0; + } + ostart += n; + return n; + } + + /** + * Returns the number of bytes that can be read from this input stream without + * blocking. The available method of InputStream + * returns 0. This method should be overridden by + * subclasses. + * + * @return the number of bytes that can be read from this input stream without + * blocking. + * @exception IOException + * if an I/O error occurs. + */ + @Override + public int available() throws IOException { + if (!useWrap) { + return inStream.available(); + } + return (ofinish - ostart); + } + + /** + * Closes this input stream and releases any system resources associated with + * the stream. + *

+ * The close method of SASLInputStream calls the + * close method of its underlying input stream. + * + * @exception IOException + * if an I/O error occurs. + */ + @Override + public void close() throws IOException { + disposeSasl(); + ostart = 0; + ofinish = 0; + inStream.close(); + isOpen = false; + } + + /** + * Tests if this input stream supports the mark and + * reset methods, which it does not. + * + * @return false, since this class does not support the + * mark and reset methods. + */ + @Override + public boolean markSupported() { + return false; + } + + @Override + public boolean isOpen() { + return isOpen; + } + + @Override + public int read(ByteBuffer dst) throws IOException { + int bytesRead = 0; + if (dst.hasArray()) { + bytesRead = read(dst.array(), dst.arrayOffset() + dst.position(), + dst.remaining()); + if (bytesRead > -1) { + dst.position(dst.position() + bytesRead); + } + } else { + byte[] buf = new byte[dst.remaining()]; + bytesRead = read(buf); + if (bytesRead > -1) { + dst.put(buf, 0, bytesRead); + } + } + return bytesRead; + } +} diff --git a/hbase-common/src/main/java/org/apache/hadoop/hbase/security/SaslOutputStream.java b/hbase-common/src/main/java/org/apache/hadoop/hbase/security/SaslOutputStream.java new file mode 100644 index 0000000..6620aed --- /dev/null +++ b/hbase-common/src/main/java/org/apache/hadoop/hbase/security/SaslOutputStream.java @@ -0,0 +1,235 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hbase.security; + +import java.io.BufferedOutputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; + +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; + +/** + * A SaslOutputStream is composed of an OutputStream and a SaslServer (or + * SaslClient) so that write() methods first process the data before writing + * them out to the underlying OutputStream. The SaslServer (or SaslClient) + * object must be fully initialized before being used by a SaslOutputStream. + */ +@InterfaceAudience.LimitedPrivate({"HDFS", "MapReduce"}) +@InterfaceStability.Evolving +public class SaslOutputStream extends OutputStream { + + private final OutputStream outStream; + // processed data ready to be written out + private byte[] saslToken; + + private final SaslClient saslClient; + private final SaslServer saslServer; + private final SaslCryptoCodec saslCodec; + // buffer holding one byte of incoming data + private final byte[] ibuffer = new byte[1]; + private final boolean useWrap; + + /** + * Constructs a SASLOutputStream from an OutputStream and a SaslServer
+ * Note: if the specified OutputStream or SaslServer is null, a + * NullPointerException may be thrown later when they are used. + * + * @param outStream + * the OutputStream to be processed + * @param saslServer + * an initialized SaslServer object + */ + public SaslOutputStream(OutputStream outStream, SaslServer saslServer) { + this.saslServer = saslServer; + this.saslClient = null; + this.saslCodec = null; + String qop = (String) saslServer.getNegotiatedProperty(Sasl.QOP); + this.useWrap = qop != null && !"auth".equalsIgnoreCase(qop); + if (useWrap) { + this.outStream = new BufferedOutputStream(outStream, 64*1024); + } else { + this.outStream = outStream; + } + } + + /** + * Constructs a SASLOutputStream from an OutputStream and a SaslClient
+ * Note: if the specified OutputStream or SaslClient is null, a + * NullPointerException may be thrown later when they are used. + * + * @param outStream + * the OutputStream to be processed + * @param saslClient + * an initialized SaslClient object + */ + public SaslOutputStream(OutputStream outStream, SaslClient saslClient) { + this.saslServer = null; + this.saslClient = saslClient; + this.saslCodec = null; + String qop = (String) saslClient.getNegotiatedProperty(Sasl.QOP); + this.useWrap = qop != null && !"auth".equalsIgnoreCase(qop); + if (useWrap) { + this.outStream = new BufferedOutputStream(outStream, 64*1024); + } else { + this.outStream = outStream; + } + } + + public SaslOutputStream(OutputStream outStream, SaslCryptoCodec saslCodec) { + this.saslServer = null; + this.saslClient = null; + this.saslCodec = saslCodec; + this.useWrap = true; + if (useWrap) { + this.outStream = new BufferedOutputStream(outStream, 64*1024); + } else { + this.outStream = outStream; + } + } + + /** + * Disposes of any system resources or security-sensitive information Sasl + * might be using. + * + * @exception SaslException + * if a SASL error occurs. + */ + private void disposeSasl() throws SaslException { + if (saslClient != null) { + saslClient.dispose(); + } + if (saslServer != null) { + saslServer.dispose(); + } + } + + /** + * Writes the specified byte to this output stream. + * + * @param b + * the byte. + * @exception IOException + * if an I/O error occurs. + */ + @Override + public void write(int b) throws IOException { + if (!useWrap) { + outStream.write(b); + return; + } + ibuffer[0] = (byte) b; + write(ibuffer, 0, 1); + } + + /** + * Writes b.length bytes from the specified byte array to this + * output stream. + *

+ * The write method of SASLOutputStream calls the + * write method of three arguments with the three arguments + * b, 0, and b.length. + * + * @param b + * the data. + * @exception NullPointerException + * if b is null. + * @exception IOException + * if an I/O error occurs. + */ + @Override + public void write(byte[] b) throws IOException { + write(b, 0, b.length); + } + + /** + * Writes len bytes from the specified byte array starting at + * offset off to this output stream. + * + * @param inBuf + * the data. + * @param off + * the start offset in the data. + * @param len + * the number of bytes to write. + * @exception IOException + * if an I/O error occurs. + */ + @Override + public void write(byte[] inBuf, int off, int len) throws IOException { + if (!useWrap) { + outStream.write(inBuf, off, len); + return; + } + try { + if (saslServer != null) { // using saslServer + saslToken = saslServer.wrap(inBuf, off, len); + } else if (saslClient != null) { // using saslClient + saslToken = saslClient.wrap(inBuf, off, len); + } else { + saslToken = saslCodec.wrap(inBuf, off, len); + } + } catch (SaslException se) { + try { + disposeSasl(); + } catch (SaslException ignored) { + } + throw se; + } + if (saslToken != null) { + ByteArrayOutputStream byteOut = new ByteArrayOutputStream(); + DataOutputStream dout = new DataOutputStream(byteOut); + dout.writeInt(saslToken.length); + outStream.write(byteOut.toByteArray()); + outStream.write(saslToken, 0, saslToken.length); + saslToken = null; + } + } + + /** + * Flushes this output stream + * + * @exception IOException + * if an I/O error occurs. + */ + @Override + public void flush() throws IOException { + outStream.flush(); + } + + /** + * Closes this output stream and releases any system resources associated with + * this stream. + * + * @exception IOException + * if an I/O error occurs. 
+ */ + @Override + public void close() throws IOException { + disposeSasl(); + outStream.close(); + } +} diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java index 0df5097..86b1c05 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java @@ -69,6 +69,7 @@ import javax.security.sasl.SaslServer; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.crypto.CipherOption; import org.apache.hadoop.hbase.CallQueueTooBigException; import org.apache.hadoop.hbase.CellScanner; import org.apache.hadoop.hbase.DoNotRetryIOException; @@ -96,16 +97,9 @@ import org.apache.hadoop.hbase.protobuf.generated.RPCProtos.ExceptionResponse; import org.apache.hadoop.hbase.protobuf.generated.RPCProtos.RequestHeader; import org.apache.hadoop.hbase.protobuf.generated.RPCProtos.ResponseHeader; import org.apache.hadoop.hbase.protobuf.generated.RPCProtos.UserInformation; -import org.apache.hadoop.hbase.security.AccessDeniedException; -import org.apache.hadoop.hbase.security.AuthMethod; -import org.apache.hadoop.hbase.security.HBasePolicyProvider; -import org.apache.hadoop.hbase.security.HBaseSaslRpcServer; +import org.apache.hadoop.hbase.security.*; import org.apache.hadoop.hbase.security.HBaseSaslRpcServer.SaslDigestCallbackHandler; import org.apache.hadoop.hbase.security.HBaseSaslRpcServer.SaslGssCallbackHandler; -import org.apache.hadoop.hbase.security.SaslStatus; -import org.apache.hadoop.hbase.security.SaslUtil; -import org.apache.hadoop.hbase.security.User; -import org.apache.hadoop.hbase.security.UserProvider; import org.apache.hadoop.hbase.security.token.AuthenticationTokenSecretManager; import org.apache.hadoop.hbase.util.Bytes; import org.apache.hadoop.hbase.util.Counter; @@ -547,8 
+541,14 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { byte [] token; // synchronization may be needed since there can be multiple Handler // threads using saslServer to wrap responses. - synchronized (connection.saslServer) { - token = connection.saslServer.wrap(responseBytes, 0, responseBytes.length); + if (connection.saslCodec != null) { + synchronized (connection.saslCodec) { + token = connection.saslCodec.wrap(responseBytes, 0, responseBytes.length); + } + } else { + synchronized (connection.saslServer) { + token = connection.saslServer.wrap(responseBytes, 0, responseBytes.length); + } } if (LOG.isTraceEnabled()) { LOG.trace("Adding saslServer wrapped token of size " + token.length @@ -1230,6 +1230,9 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { private ByteBuffer unwrappedDataLengthBuffer = ByteBuffer.allocate(4); boolean useSasl; SaslServer saslServer; + SaslCryptoCodec saslCodec; + private boolean useNegotiatedCipher = false; + private String clientCiphers; private boolean useWrap = false; // Fake 'call' for failed authorization response private static final int AUTHORIZATION_FAILED_CALLID = -1; @@ -1351,7 +1354,12 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { processOneRpc(saslToken); } else { byte[] b = saslToken.array(); - byte [] plaintextData = saslServer.unwrap(b, saslToken.position(), saslToken.limit()); + byte [] plaintextData = null; + if (saslCodec != null) { + plaintextData = saslCodec.unwrap(b, saslToken.position(), saslToken.limit()); + } else { + plaintextData = saslServer.unwrap(b, saslToken.position(), saslToken.limit()); + } processUnwrappedData(plaintextData); } } else { @@ -1414,7 +1422,7 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { } cause = cause.getCause(); } - doRawSaslReply(SaslStatus.ERROR, null, sendToClient.getClass().getName(), + doRawSaslReply(SaslStatus.ERROR, null, null, 
sendToClient.getClass().getName(), sendToClient.getLocalizedMessage()); metrics.authenticationFailure(); String clientIP = this.toString(); @@ -1427,8 +1435,29 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { LOG.debug("Will send token of size " + replyToken.length + " from saslServer."); } - doRawSaslReply(SaslStatus.SUCCESS, new BytesWritable(replyToken), null, + doRawSaslReply(SaslStatus.SUCCESS, new BytesWritable(replyToken), null, null, null); + } else if (useNegotiatedCipher && saslServer.isComplete() && + SaslUtil.isNegotiatedQopPrivacy(saslServer)) { + // Negotiate a cipher option + CipherOption cipherOption = SaslUtil.negotiateCipherOption( + conf, SaslUtil.getCipherOptions(clientCiphers)); + if (LOG.isDebugEnabled()) { + if (cipherOption == null) { + // No cipher suite is negotiated + LOG.debug("Server not using any cipher suite" + + " with client " + hostAddress); + } else { + LOG.debug("Server using cipher suite " + + cipherOption.getCipherSuite().getName() + + " with client " + hostAddress); + } + } + if (cipherOption != null) { + saslCodec = new SaslCryptoCodec(conf, cipherOption, true); + } + CipherOption wrappedCipherOption = SaslUtil.wrap(cipherOption, saslServer); + doRawSaslReply(SaslStatus.SUCCESS, null, wrappedCipherOption, null, null); } if (saslServer.isComplete()) { String qop = (String) saslServer.getNegotiatedProperty(Sasl.QOP); @@ -1449,8 +1478,8 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { /** * No protobuf encoding of raw sasl messages */ - private void doRawSaslReply(SaslStatus status, Writable rv, - String errorClass, String error) throws IOException { + private void doRawSaslReply(SaslStatus status, Writable rv, CipherOption cipherOption, + String errorClass, String error) throws IOException { ByteBufferOutputStream saslResponse = null; DataOutputStream out = null; try { @@ -1460,7 +1489,18 @@ public class RpcServer implements RpcServerInterface, 
ConfigurationObserver { out = new DataOutputStream(saslResponse); out.writeInt(status.state); // write status if (status == SaslStatus.SUCCESS) { - rv.write(out); + if (cipherOption != null) { + out.writeInt(SaslUtil.USE_NEGOTIATED_CIPHER); + WritableUtils.writeString(out, cipherOption.getCipherSuite().getName()); + new BytesWritable(cipherOption.getInKey()).write(out); + new BytesWritable(cipherOption.getInIv()).write(out); + new BytesWritable(cipherOption.getOutKey()).write(out); + new BytesWritable(cipherOption.getOutIv()).write(out); + } else if (rv != null) { + rv.write(out); + } else { + out.writeInt(0); // Write 0 length (can indicate no ciphers) + } } else { WritableUtils.writeString(out, errorClass); WritableUtils.writeString(out, error); @@ -1529,7 +1569,7 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { } if (!isSecurityEnabled && authMethod != AuthMethod.SIMPLE) { doRawSaslReply(SaslStatus.SUCCESS, new IntWritable( - SaslUtil.SWITCH_TO_SIMPLE_AUTH), null, null); + SaslUtil.SWITCH_TO_SIMPLE_AUTH), null, null, null); authMethod = AuthMethod.SIMPLE; // client has already sent the initial Sasl message and we // should ignore it. Both client and server should fall back @@ -1593,6 +1633,41 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { if (!useWrap) { //covers the !useSasl too dataLengthBuffer.clear(); return 0; //ping message + } else if (dataLength == SaslUtil.USE_NEGOTIATED_CIPHER) { + useNegotiatedCipher = true; + + // Read the client ciphers length. + dataLengthBuffer.clear(); + count = read4Bytes(); + if (count < 0 || dataLengthBuffer.remaining() > 0) { + return count; + } + dataLengthBuffer.flip(); + dataLength = dataLengthBuffer.getInt(); + if (dataLength < 0) { // A data length of zero is legal. + throw new IllegalArgumentException("Unexpected data length " + + dataLength + "!! from " + getHostAddress() + " while reading client ciphers"); + } + + // Read the client ciphers. 
+ ByteBuffer ciphers = ByteBuffer.allocate(dataLength); + count = channelRead(channel, ciphers); + if (count < 0 || ciphers.remaining() > 0) { + return count; + } + clientCiphers = new String(ciphers.array(), "UTF-8"); + if (LOG.isDebugEnabled()) { + LOG.debug("Have read client ciphers: " + clientCiphers); + } + + // Read the data length. + dataLengthBuffer.clear(); + count = read4Bytes(); + if (count < 0 || dataLengthBuffer.remaining() > 0) { + return count; + } + dataLengthBuffer.flip(); + dataLength = dataLengthBuffer.getInt(); } } if (dataLength < 0) { // A data length of zero is legal.