From 5fb9ba73e10abdb4384726bc8859914da4b3a254 Mon Sep 17 00:00:00 2001 From: Sriharsha Chintalapani Date: Tue, 10 Mar 2015 15:50:24 -0700 Subject: [PATCH 1/2] KAFKA-1684. Implement TLS/SSL authentication. --- core/src/main/scala/kafka/network/Channel.scala | 126 ++++++ .../main/scala/kafka/network/SocketServer.scala | 145 ++++--- .../main/scala/kafka/network/ssl/SSLChannel.scala | 429 +++++++++++++++++++++ .../kafka/network/ssl/SSLConnectionConfig.scala | 52 +++ core/src/main/scala/kafka/server/KafkaConfig.scala | 32 +- core/src/main/scala/kafka/server/KafkaServer.scala | 2 + core/src/main/scala/kafka/utils/SSLAuthUtils.scala | 97 +++++ .../unit/kafka/network/SocketServerTest.scala | 37 +- .../kafka/server/KafkaConfigConfigDefTest.scala | 7 +- .../test/scala/unit/kafka/utils/TestSSLUtils.scala | 178 +++++++++ 10 files changed, 1051 insertions(+), 54 deletions(-) create mode 100644 core/src/main/scala/kafka/network/Channel.scala create mode 100644 core/src/main/scala/kafka/network/ssl/SSLChannel.scala create mode 100644 core/src/main/scala/kafka/network/ssl/SSLConnectionConfig.scala create mode 100644 core/src/main/scala/kafka/utils/SSLAuthUtils.scala create mode 100644 core/src/test/scala/unit/kafka/utils/TestSSLUtils.scala diff --git a/core/src/main/scala/kafka/network/Channel.scala b/core/src/main/scala/kafka/network/Channel.scala new file mode 100644 index 0000000..48133ca --- /dev/null +++ b/core/src/main/scala/kafka/network/Channel.scala @@ -0,0 +1,126 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.network + +import java.io.IOException +import java.nio.ByteBuffer +import java.nio.channels.ReadableByteChannel +import java.nio.channels.GatheringByteChannel +import java.nio.channels.SelectionKey +import java.nio.channels.Selector +import java.nio.channels.SocketChannel +import java.nio.channels.ClosedChannelException +import java.nio.channels.spi.AbstractSelectableChannel + +import kafka.utils.Logging + +class Channel( var socketChannel: SocketChannel ) extends ReadableByteChannel with GatheringByteChannel with Logging { + + /** + * Returns true if the network buffer has been flushed outu and is empty. + */ + def flush() : Boolean = { + true + } + + /** + * Closes this channel. + * + * @throws IOException If an I/O error occurs + */ + @throws(classOf[IOException]) + def close() { + socketChannel.socket().close() + socketChannel.close() + } + + @throws(classOf[IOException]) + def close(force: Boolean) { + if(isOpen || force) close() + } + + /** + * Tells wheter or not this channel is open. + */ + override def isOpen: Boolean = { + socketChannel.isOpen() + } + + /** + * Writes a sequence of bytes to this channel from the given buffer. 
+ */ + @throws(classOf[IOException]) + override def write(src: ByteBuffer) : Int = { + socketChannel.write(src) + } + + @throws(classOf[IOException]) + override def write(srcs: Array[ByteBuffer]) : Long = { + socketChannel.write(srcs) + } + + @throws(classOf[IOException]) + override def write(srcs: Array[ByteBuffer], offset: Int, length: Int) : Long = { + socketChannel.write(srcs, offset, length) + } + + @throws(classOf[IOException]) + override def read(dst: ByteBuffer) : Int = { + socketChannel.read(dst) + } + + def getIOChannel : SocketChannel = { + socketChannel + } + + def isHandshakeComplete(): Boolean = { + true + } + + @throws(classOf[ClosedChannelException]) + def register(selector: Selector, ops: Int): SelectionKey = { + socketChannel.register(selector, ops) + } + + @throws(classOf[IOException]) + def flushOutbound: Boolean = { + false + } + + + /** + * Performs SSL handshake hence is a no-op for the non-secure + * implementation + * @param read Unused in non-secure implementation + * @param write Unused in non-secure implementation + * @return Always return 0 + * @throws IOException + */ + @throws(classOf[IOException]) + def handshake(read: Boolean, write: Boolean): Int = { + return 0 + } + + override def toString(): String = { + super.toString()+":"+this.socketChannel.toString() + } + + def getOutboundRemaining(): Int = { + return 0 + } +} diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala index 76ce41a..fdc6bd8 100644 --- a/core/src/main/scala/kafka/network/SocketServer.scala +++ b/core/src/main/scala/kafka/network/SocketServer.scala @@ -23,12 +23,16 @@ import java.util.concurrent.atomic._ import java.net._ import java.io._ import java.nio.channels._ +import javax.net.ssl.SSLEngine import scala.collection._ import kafka.common.KafkaException import kafka.metrics.KafkaMetricsGroup import kafka.utils._ +import kafka.network.ssl._ +import kafka.utils.SSLAuthUtils + import 
com.yammer.metrics.core.{Gauge, Meter} /** @@ -47,11 +51,13 @@ class SocketServer(val brokerId: Int, val maxRequestSize: Int = Int.MaxValue, val maxConnectionsPerIp: Int = Int.MaxValue, val connectionsMaxIdleMs: Long, + val sslEnable: Boolean = false, + val sslConfigFilePath: String = null, val maxConnectionsPerIpOverrides: Map[String, Int] ) extends Logging with KafkaMetricsGroup { this.logIdent = "[Socket Server on Broker " + brokerId + "], " private val time = SystemTime private val processors = new Array[Processor](numProcessorThreads) - @volatile private var acceptor: Acceptor = null + @volatile private var acceptors: List[Acceptor] = List() val requestChannel = new RequestChannel(numProcessorThreads, maxQueuedRequests) /* a meter to track the average free capacity of the network processors */ @@ -63,12 +69,12 @@ class SocketServer(val brokerId: Int, def startup() { val quotas = new ConnectionQuotas(maxConnectionsPerIp, maxConnectionsPerIpOverrides) for(i <- 0 until numProcessorThreads) { - processors(i) = new Processor(i, - time, - maxRequestSize, + processors(i) = new Processor(i, + time, + maxRequestSize, aggregateIdleMeter, newMeter("IdlePercent", "percent", TimeUnit.NANOSECONDS, Map("networkProcessor" -> i.toString)), - numProcessorThreads, + numProcessorThreads, requestChannel, quotas, connectionsMaxIdleMs) @@ -81,12 +87,18 @@ class SocketServer(val brokerId: Int, // register the processor threads for notification of responses requestChannel.addResponseListener((id:Int) => processors(id).wakeup()) - + // start accepting connections - this.acceptor = new Acceptor(host, port, processors, sendBufferSize, recvBufferSize, quotas) - Utils.newThread("kafka-socket-acceptor", acceptor, false).start() - acceptor.awaitStartup - info("Started") + this.acceptors ++= List(new Acceptor(host, port, processors, sendBufferSize, recvBufferSize, quotas)) + if (sslEnable) { + val sslConnectionConfig = new SSLConnectionConfig(sslConfigFilePath) + this.acceptors ++= List(new 
Acceptor(sslConnectionConfig.host, sslConnectionConfig.port, processors, sendBufferSize, recvBufferSize, quotas, sslConnectionConfig, sslEnable)) + } + for (acceptor <- acceptors) { + Utils.newThread("kafka-socket-acceptor", acceptor, false).start() + acceptor.awaitStartup + info("Started") + } } /** @@ -94,8 +106,10 @@ class SocketServer(val brokerId: Int, */ def shutdown() = { info("Shutting down") - if(acceptor != null) - acceptor.shutdown() + for(acceptor <- acceptors) { + if(acceptor != null) + acceptor.shutdown() + } for(processor <- processors) processor.shutdown() info("Shutdown completed") @@ -142,12 +156,12 @@ private[kafka] abstract class AbstractServerThread(connectionQuotas: ConnectionQ * Is the server still running? */ protected def isRunning = alive.get - + /** * Wakeup the thread for selection. */ def wakeup() = selector.wakeup() - + /** * Close the given key and associated socket */ @@ -158,7 +172,7 @@ private[kafka] abstract class AbstractServerThread(connectionQuotas: ConnectionQ swallowError(key.cancel()) } } - + def close(channel: SocketChannel) { if(channel != null) { debug("Closing connection from " + channel.socket.getRemoteSocketAddress()) @@ -167,13 +181,13 @@ private[kafka] abstract class AbstractServerThread(connectionQuotas: ConnectionQ swallowError(channel.close()) } } - + /** * Close all open connections */ def closeAll() { // removes cancelled keys from selector.keys set - this.selector.selectNow() + this.selector.selectNow() val iter = this.selector.keys().iterator() while (iter.hasNext) { val key = iter.next() @@ -196,13 +210,16 @@ private[kafka] abstract class AbstractServerThread(connectionQuotas: ConnectionQ /** * Thread that accepts and configures new connections. 
There is only need for one of these */ -private[kafka] class Acceptor(val host: String, - val port: Int, +private[kafka] class Acceptor(val host: String, + val port: Int, private val processors: Array[Processor], - val sendBufferSize: Int, + val sendBufferSize: Int, val recvBufferSize: Int, - connectionQuotas: ConnectionQuotas) extends AbstractServerThread(connectionQuotas) { + connectionQuotas: ConnectionQuotas, + val sslConnectionConfig: SSLConnectionConfig = null, + val sslEnable: Boolean = false) extends AbstractServerThread(connectionQuotas) { val serverChannel = openServerSocket(host, port) + if(sslEnable) SSLAuthUtils.initializeSSLAuth(sslConnectionConfig) /** * Accept loop that checks for new connection attempts @@ -239,12 +256,12 @@ private[kafka] class Acceptor(val host: String, swallowError(selector.close()) shutdownComplete() } - + /* * Create a server socket to listen for connections on. */ def openServerSocket(host: String, port: Int): ServerSocketChannel = { - val socketAddress = + val socketAddress = if(host == null || host.trim.isEmpty) new InetSocketAddress(port) else @@ -256,7 +273,7 @@ private[kafka] class Acceptor(val host: String, serverChannel.socket.bind(socketAddress) info("Awaiting socket connections on %s:%d.".format(socketAddress.getHostName, port)) } catch { - case e: SocketException => + case e: SocketException => throw new KafkaException("Socket server failed to bind to %s:%d: %s.".format(socketAddress.getHostName, port, e.getMessage), e) } serverChannel @@ -277,9 +294,15 @@ private[kafka] class Acceptor(val host: String, debug("Accepted connection from %s on %s. 
sendBufferSize [actual|requested]: [%d|%d] recvBufferSize [actual|requested]: [%d|%d]" .format(socketChannel.socket.getInetAddress, socketChannel.socket.getLocalSocketAddress, socketChannel.socket.getSendBufferSize, sendBufferSize, - socketChannel.socket.getReceiveBufferSize, recvBufferSize)) - - processor.accept(socketChannel) + socketChannel.socket.getReceiveBufferSize, recvBufferSize)) + var channel:Channel = null + if (sslEnable) { + val sslEngine = createSSLEngine(socketChannel) + channel = new SSLChannel(socketChannel, sslEngine) + } else + channel = new Channel(socketChannel) + + processor.accept(socketChannel, channel) } catch { case e: TooManyConnectionsException => info("Rejected connection from %s, address already has the configured maximum of %d connections.".format(e.ip, e.count)) @@ -287,6 +310,16 @@ private[kafka] class Acceptor(val host: String, } } + private def createSSLEngine(socketChannel: SocketChannel) : SSLEngine = { + val sslEngine = SSLAuthUtils.sslContext.createSSLEngine(socketChannel.socket.getInetAddress.getHostName, socketChannel.socket.getPort) + //do not allow ssl, sslv2, sslv3 as they have known vulnerabilities + sslEngine.setEnabledProtocols(Array("TLSv1.2", "TLSv1.1", "TLSv1", "SSLv2Hello")) + sslEngine.setNeedClientAuth(sslConnectionConfig.needClientAuth) + sslEngine.setWantClientAuth(sslConnectionConfig.wantClientAuth) + sslEngine.setUseClientMode(false) + sslEngine + } + } /** @@ -308,6 +341,7 @@ private[kafka] class Processor(val id: Int, private var currentTimeNanos = SystemTime.nanoseconds private val lruConnections = new util.LinkedHashMap[SelectionKey, Long] private var nextIdleCloseCheckTime = currentTimeNanos + connectionsMaxIdleNanos + private val socketContainer = new ConcurrentHashMap[SocketChannel, Channel]() override def run() { startupComplete() @@ -333,17 +367,31 @@ private[kafka] class Processor(val id: Int, val iter = keys.iterator() while(iter.hasNext && isRunning) { var key: SelectionKey = null + var channel: 
Channel = null try { key = iter.next + channel = socketContainer.get(channelFor(key)) iter.remove() - if(key.isReadable) - read(key) - else if(key.isWritable) - write(key) - else if(!key.isValid) + if(key.isReadable) { + if (channel.isHandshakeComplete()) { + read(key) + } else { + val handShakeStatus = channel.handshake(key.isReadable(), key.isWritable()) + if (handShakeStatus == 0) key.interestOps(SelectionKey.OP_READ) else key.interestOps(handShakeStatus) + } + } else if(key.isWritable) { + if (channel.isHandshakeComplete()) + write(key) + else { + val handShakeStatus = channel.handshake(key.isReadable(), key.isWritable()) + if (handShakeStatus == 0) key.interestOps(SelectionKey.OP_READ) else key.interestOps(handShakeStatus) + } + + } else if(!key.isValid) { close(key) - else + } else { throw new IllegalStateException("Unrecognized key state for processor thread.") + } } catch { case e: EOFException => { info("Closing socket connection to %s.".format(channelFor(key).socket.getInetAddress)) @@ -371,6 +419,9 @@ private[kafka] class Processor(val id: Int, */ override def close(key: SelectionKey): Unit = { lruConnections.remove(key) + val channel = socketContainer.get(channelFor(key)) + channel.close() + socketContainer.remove(channelFor(key)) super.close(key) } @@ -414,8 +465,9 @@ private[kafka] class Processor(val id: Int, /** * Queue up a new connection for reading */ - def accept(socketChannel: SocketChannel) { + def accept(socketChannel: SocketChannel, channel: Channel) { newConnections.add(socketChannel) + socketContainer.put(socketChannel, channel) wakeup() } @@ -435,15 +487,14 @@ private[kafka] class Processor(val id: Int, */ def read(key: SelectionKey) { lruConnections.put(key, currentTimeNanos) - val socketChannel = channelFor(key) + val channel = socketContainer.get(channelFor(key)) var receive = key.attachment.asInstanceOf[Receive] if(key.attachment == null) { receive = new BoundedByteBufferReceive(maxRequestSize) key.attach(receive) } - val read = 
receive.readFrom(socketChannel) - val address = socketChannel.socket.getRemoteSocketAddress(); - trace(read + " bytes read from " + address) + val read = receive.readFrom(channel) + val address = channel.getIOChannel.socket.getRemoteSocketAddress(); if(read < 0) { close(key) } else if(receive.complete) { @@ -454,7 +505,7 @@ private[kafka] class Processor(val id: Int, key.interestOps(key.interestOps & (~SelectionKey.OP_READ)) } else { // more reading to be done - trace("Did not finish reading, registering for read again on connection " + socketChannel.socket.getRemoteSocketAddress()) + trace("Did not finish reading, registering for read again on connection " + channel.getIOChannel.socket.getRemoteSocketAddress()) key.interestOps(SelectionKey.OP_READ) wakeup() } @@ -464,20 +515,20 @@ private[kafka] class Processor(val id: Int, * Process writes to ready sockets */ def write(key: SelectionKey) { - val socketChannel = channelFor(key) + val channel = socketContainer.get(channelFor(key)) val response = key.attachment().asInstanceOf[RequestChannel.Response] val responseSend = response.responseSend if(responseSend == null) throw new IllegalStateException("Registered for write interest but no response attached to key.") - val written = responseSend.writeTo(socketChannel) - trace(written + " bytes written to " + socketChannel.socket.getRemoteSocketAddress() + " using key " + key) + val written = responseSend.writeTo(channel) + trace(written + " bytes written to " + channel.getIOChannel.socket.getRemoteSocketAddress() + " using key " + key) if(responseSend.complete) { response.request.updateRequestMetrics() key.attach(null) - trace("Finished writing, registering for read on connection " + socketChannel.socket.getRemoteSocketAddress()) + trace("Finished writing, registering for read on connection " + channel.getIOChannel.socket.getRemoteSocketAddress()) key.interestOps(SelectionKey.OP_READ) } else { - trace("Did not finish writing, registering for write again on connection " + 
socketChannel.socket.getRemoteSocketAddress()) + trace("Did not finish writing, registering for write again on connection " + channel.getIOChannel.socket.getRemoteSocketAddress()) key.interestOps(SelectionKey.OP_WRITE) wakeup() } @@ -508,7 +559,7 @@ private[kafka] class Processor(val id: Int, class ConnectionQuotas(val defaultMax: Int, overrideQuotas: Map[String, Int]) { private val overrides = overrideQuotas.map(entry => (InetAddress.getByName(entry._1), entry._2)) private val counts = mutable.Map[InetAddress, Int]() - + def inc(addr: InetAddress) { counts synchronized { val count = counts.getOrElse(addr, 0) @@ -518,7 +569,7 @@ class ConnectionQuotas(val defaultMax: Int, overrideQuotas: Map[String, Int]) { throw new TooManyConnectionsException(addr, max) } } - + def dec(addr: InetAddress) { counts synchronized { val count = counts.get(addr).get @@ -528,7 +579,7 @@ class ConnectionQuotas(val defaultMax: Int, overrideQuotas: Map[String, Int]) { counts.put(addr, count - 1) } } - + } class TooManyConnectionsException(val ip: InetAddress, val count: Int) extends KafkaException("Too many connections from %s (maximum = %d)".format(ip, count)) diff --git a/core/src/main/scala/kafka/network/ssl/SSLChannel.scala b/core/src/main/scala/kafka/network/ssl/SSLChannel.scala new file mode 100644 index 0000000..01cbd6e --- /dev/null +++ b/core/src/main/scala/kafka/network/ssl/SSLChannel.scala @@ -0,0 +1,429 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.network.ssl + +import java.io.EOFException +import java.io.IOException +import java.net.SocketTimeoutException +import java.nio.ByteBuffer +import java.nio.channels.SelectionKey +import java.nio.channels.Selector +import java.nio.channels.SocketChannel + +import javax.net.ssl.SSLEngine +import javax.net.ssl.SSLEngineResult +import javax.net.ssl.SSLEngineResult.HandshakeStatus +import javax.net.ssl.SSLEngineResult.HandshakeStatus._ +import javax.net.ssl.SSLEngineResult.Status + +import kafka.network.Channel +import kafka.utils.Logging + +class SSLChannel(socketChannel: SocketChannel, sslEngine: SSLEngine) extends Channel(socketChannel) with Logging { + val netInBuffer: ByteBuffer = ByteBuffer.allocateDirect(sslEngine.getSession.getApplicationBufferSize) + val netOutBuffer: ByteBuffer = ByteBuffer.allocateDirect(sslEngine.getSession.getPacketBufferSize) + val appBuffer: AppBufferHandler = new AppBufferHandler(sslEngine.getSession.getApplicationBufferSize, + sslEngine.getSession.getApplicationBufferSize) + + val emptyBuf: ByteBuffer = ByteBuffer.allocate(0) + var handshakeStatus: HandshakeStatus = null + var handshakeResult: SSLEngineResult = null + var handshakeComplete: Boolean = false + var closed: Boolean = false + var closing: Boolean = false + startHandshake + + /********** NIO SSL METHODS ************/ + + /** + * Start the ssl handshake. 
+ */ + + def startHandshake { + netOutBuffer.position(0) + netOutBuffer.limit(0) + netInBuffer.position(0) + netInBuffer.limit(0) + handshakeComplete = false + closed = false + closing = false + //initiate handshake + sslEngine.beginHandshake() + handshakeStatus = sslEngine.getHandshakeStatus() + } + + /** + * Flush the channel. + */ + override def flush: Boolean = { + flush(netOutBuffer) + } + + /** + * Flushes the buffer to the network, non blocking + * @param buf ByteBuffer + * @return boolean true if the buffer has been emptied out, false otherwise + * @throws IOException + */ + def flush(buf: ByteBuffer) : Boolean = { + val remaining = buf.remaining() + if ( remaining > 0 ) { + val written = socketChannel.write(buf) + return written >= remaining + } + return true + } + + /** + * Performs SSL handshake, non blocking, but performs NEED_TASK on the same thread. + * Hence, you should never call this method using your Acceptor thread, as you would slow down + * your system significantly. + * The return for this operation is 0 if the handshake is complete and a positive value if it is not complete. + * In the event of a positive value coming back, reregister the selection key for the return values interestOps. 
+ * @param read boolean - true if the underlying channel is readable + * @param write boolean - true if the underlying channel is writable + * @return int - 0 if hand shake is complete, otherwise it returns a SelectionKey interestOps value + * @throws IOException + */ + @throws(classOf[IOException]) + override def handshake(read: Boolean, write: Boolean): Int = { + if ( handshakeComplete ) return 0 //we have done our initial handshake + + if (!flush(netOutBuffer)) return SelectionKey.OP_WRITE //we still have data to write + handshakeStatus = sslEngine.getHandshakeStatus() + + handshakeStatus match { + case NOT_HANDSHAKING => + // SSLEnginge.getHandshakeStatus is transient and it doesn't record FINISHED status properly + if (handshakeResult.getHandshakeStatus() == FINISHED) { + handshakeComplete = !netOutBuffer.hasRemaining() + return if (handshakeComplete) 0 else SelectionKey.OP_WRITE + } else { + //should never happen + throw new IOException("NOT_HANDSHAKING during handshake") + } + + case FINISHED => + //we are complete if we have delivered the last package + handshakeComplete = !netOutBuffer.hasRemaining() + //return 0 if we are complete, otherwise we still have data to write + return if (handshakeComplete) 0 else SelectionKey.OP_WRITE + + case NEED_WRAP => + //perform the wrap function + handshakeResult = handshakeWrap(write) + if ( handshakeResult.getStatus() == Status.OK ) { + if (handshakeStatus == HandshakeStatus.NEED_TASK) + handshakeStatus = tasks() + } else { + //wrap should always work with our buffers + throw new IOException("Unexpected status [%s] during handshake WRAP.".format(handshakeResult.getStatus())) + } + if ( handshakeStatus != HandshakeStatus.NEED_UNWRAP || (!flush(netOutBuffer)) ) + return SelectionKey.OP_WRITE + + case NEED_UNWRAP => + //perform the unwrap function + handshakeResult = handshakeUnwrap(read) + if ( handshakeResult.getStatus() == Status.OK ) { + if (handshakeStatus == HandshakeStatus.NEED_TASK) + handshakeStatus = tasks() + } 
else if ( handshakeResult.getStatus() == Status.BUFFER_UNDERFLOW ){ + //read more data, reregister for OP_READ + return SelectionKey.OP_READ + } else { + throw new IOException("Unexpected status [%s] during handshake UNWRAP".format(handshakeStatus)) + } + + case NEED_TASK => + handshakeStatus = tasks() + case _ => throw new IllegalStateException("Unexpected status [%s]".format(handshakeStatus)) + } + //return 0 if we are complete, otherwise reregister for any activity that + //would cause this method to be called again. + return if (handshakeComplete) 0 else (SelectionKey.OP_WRITE|SelectionKey.OP_READ) + } + + /** + * Executes all the tasks needed on the same thread. + * @return HandshakeStatus + */ + def tasks(): SSLEngineResult.HandshakeStatus = { + var r: Runnable = null + while ({r = sslEngine.getDelegatedTask(); r != null}) r.run() + sslEngine.getHandshakeStatus() + } + + /** + * Performs the WRAP function + * @param doWrite boolean + * @return SSLEngineResult + * @throws IOException + */ + @throws(classOf[IOException]) + private def handshakeWrap(doWrite: Boolean): SSLEngineResult = { + //this should never be called with a network buffer that contains data + //so we can clear it here. 
+ netOutBuffer.clear() + //perform the wrap + val result: SSLEngineResult = sslEngine.wrap(appBuffer.writeBuf, netOutBuffer) + //prepare the results to be written + netOutBuffer.flip() + //set the status + handshakeStatus = result.getHandshakeStatus() + //optimization, if we do have a writable channel, write it now + if ( doWrite ) flush(netOutBuffer) + result + } + + /** + * Perform handshake unwrap + * @param doread boolean + * @return SSLEngineResult + * @throws IOException + */ + @throws(classOf[IOException]) + private def handshakeUnwrap(doread: Boolean): SSLEngineResult = { + + if (netInBuffer.position() == netInBuffer.limit()) { + //clear the buffer if we have emptied it out on data + netInBuffer.clear() + } + if ( doread ) { + //if we have data to read, read it + val read = socketChannel.read(netInBuffer) + if (read == -1) throw new IOException("EOF during handshake.") + } + var result: SSLEngineResult = null + var cont: Boolean = false + //loop while we can perform pure SSLEngine data + do { + //prepare the buffer with the incoming data + netInBuffer.flip() + //call unwrap + result = sslEngine.unwrap(netInBuffer, appBuffer.writeBuf) + //compact the buffer, this is an optional method, wonder what would happen if we didn't + netInBuffer.compact() + //read in the status + handshakeStatus = result.getHandshakeStatus() + if ( result.getStatus() == SSLEngineResult.Status.OK && + result.getHandshakeStatus() == HandshakeStatus.NEED_TASK ) { + //execute tasks if we need to + handshakeStatus = tasks() + } + //perform another unwrap? 
+ cont = result.getStatus() == SSLEngineResult.Status.OK && + handshakeStatus == HandshakeStatus.NEED_UNWRAP + }while ( cont ) + result + } + + override def getOutboundRemaining: Int = { + netOutBuffer.remaining() + } + + @throws(classOf[IOException]) + override def flushOutbound(): Boolean = { + val remaining = netOutBuffer.remaining() + flush(netOutBuffer) + val remaining2= netOutBuffer.remaining() + remaining2 < remaining + } + + /** + * Sends a SSL close message, will not physically close the connection here.
+ * @throws IOException if an I/O error occurs + * @throws IOException if there is data on the outgoing network buffer and we are unable to flush it + */ + + override def close() { + if (closing) return + closing = true + sslEngine.closeOutbound() + + if (!flush(netOutBuffer)) { + throw new IOException("Remaining data in the network buffer, can't send SSL close message, force a close with close(true) instead") + } + //prep the buffer for the close message + netOutBuffer.clear() + //perform the close, since we called sslEngine.closeOutbound + val handshake: SSLEngineResult = sslEngine.wrap(emptyBuf, netOutBuffer) + //we should be in a close state + if (handshake.getStatus() != SSLEngineResult.Status.CLOSED) { + throw new IOException("Invalid close state, will not send network data.") + } + //prepare the buffer for writing + netOutBuffer.flip() + //if there is data to be written + flush(netOutBuffer) + + //is the channel closed? + closed = (!netOutBuffer.hasRemaining() && (handshake.getHandshakeStatus() != HandshakeStatus.NEED_WRAP)) + } + + + /** + * Force a close, can throw an IOException + * @param force boolean + * @throws IOException + */ + @throws(classOf[IOException]) + override def close(force: Boolean) { + try { + close() + }finally { + if ( force || closed ) { + closed = true + socketChannel.socket().close() + socketChannel.close() + } + } + } + + override def isHandshakeComplete(): Boolean = { + handshakeComplete + } + + /** + * Reads a sequence of bytes from this channel into the given buffer. 
+ * + * @param dst The buffer into which bytes are to be transferred + * @return The number of bytes read, possible zero or -1 if the channel has reached end-of-stream + * @throws IOException if some other I/O error occurs + * @throws IlleagalArgumentException if the destination buffer is different than appBufHandler.getReadBuffer() + */ + @throws(classOf[IOException]) + override def read(dst: ByteBuffer): Int = { + if (closing || closed) return -1 + if (!handshakeComplete) throw new IllegalStateException("Handshake incomplete.") + + val netread = socketChannel.read(netInBuffer) + if (netread == -1) return -1 + var read = 0 + var unwrap: SSLEngineResult = null + do { + //prepare the buffer + netInBuffer.flip() + //unwrap the data + unwrap = sslEngine.unwrap(netInBuffer, appBuffer.readBuf) + //compact the buffer + netInBuffer.compact() + if(unwrap.getStatus() == Status.OK || unwrap.getStatus() == Status.BUFFER_UNDERFLOW) { + //we did receive some data, add it to our total + read += unwrap.bytesProduced + // perform any task if needed + if(unwrap.getHandshakeStatus() == HandshakeStatus.NEED_TASK) tasks() + //if we need more network data, than return for now. + if(unwrap.getStatus() == Status.BUFFER_UNDERFLOW) return readFromAppBuffer(dst) + } else if (unwrap.getStatus() == Status.BUFFER_OVERFLOW && read > 0) { + //buffer overflow can happen, if we have read data, then + //empty out the dst buffer before we do another read + return readFromAppBuffer(dst) + } else { + //here we should trap BUFFER_OVERFLOW and call expand on the buffer + // for now, throw an exception, as we initialized the buffers + // in constructor + throw new IOException("Unable to unwrap data, invalid status [%d]".format(unwrap.getStatus())) + } + } while(netInBuffer.position() != 0) + readFromAppBuffer(dst) + } + + /** + * Writes a sequence of bytes to this channel from the given buffer. 
+ * + * @param src The buffer from which bytes are to be retrieved + * @return The number of bytes written, possibly zero + * @throws IOException If some other I/O error occurs + */ + + @throws(classOf[IOException]) + override def write(src: ByteBuffer): Int = { + var written = 0 + if(src == this.netOutBuffer) + written = socketChannel.write(src) + else { + if (closing || closed) throw new IOException("Channel is in closing state") + + //we haven't emptied out the buffer yet + if (!flush(netOutBuffer)) + return written + netOutBuffer.clear + val result = sslEngine.wrap(src, netOutBuffer) + written = result.bytesConsumed() + netOutBuffer.flip() + if (result.getStatus() == Status.OK && result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) + tasks() + else + throw new IOException("Unable to wrap data, invalid status %s".format(result.getStatus())) + flush(netOutBuffer) + } + written + } + + + @throws(classOf[IOException]) + override def write(src: Array[ByteBuffer], offset: Int, length: Int): Long = { + var totalWritten = 0 + var i = offset + for( i <- offset until length) { + if (src(i).hasRemaining) { + val written = write(src(i)) + if (written > 0) { + totalWritten += written + if(!src(i).hasRemaining) + return totalWritten + } else + return totalWritten + } + } + totalWritten + } + + @throws(classOf[IOException]) + override def write(src: Array[ByteBuffer]): Long = { + write(src, 0, src.length) + } + + private[this] def readFromAppBuffer(dst: ByteBuffer): Int = { + appBuffer.readBuf.flip() + try { + var remaining = appBuffer.readBuf.remaining + if(remaining > 0) { + if(remaining > dst.remaining) + remaining = dst.remaining + var i = 0 + while (i < remaining) { + dst.put(appBuffer.readBuf.get) + i = i + 1 + } + } + remaining + } finally { + appBuffer.readBuf.compact() + } + } + + case class AppBufferHandler(readBuf: ByteBuffer, + writeBuf: ByteBuffer) { + def this(readSize: Int, writeSize: Int) = { + this(ByteBuffer.allocate(readSize), 
ByteBuffer.allocate(writeSize)) + } + } + +} diff --git a/core/src/main/scala/kafka/network/ssl/SSLConnectionConfig.scala b/core/src/main/scala/kafka/network/ssl/SSLConnectionConfig.scala new file mode 100644 index 0000000..567175b --- /dev/null +++ b/core/src/main/scala/kafka/network/ssl/SSLConnectionConfig.scala @@ -0,0 +1,52 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.network.ssl + +import kafka.network.ConnectionConfig +import kafka.utils.{Utils, VerifiableProperties} + +class SSLConnectionConfig(path: String) extends ConnectionConfig { + val props = new VerifiableProperties(Utils.loadProps(path)) + + val host: String = props.getString("host", "") + + val port: Int = props.getInt("port", 9093) + + /** Keystore file location */ + val keystore = props.getString("keystore") + + /** Keystore file password */ + val keystorePwd = props.getString("keystorePwd") + + /** Keystore key password */ + val keyPwd = props.getString("keyPwd") + + /** Truststore file location */ + val truststore = props.getString("truststore") + + /** Truststore file password */ + val truststorePwd = props.getString("truststorePwd") + + val keystoreType = props.getString("keystore.type") + + /** Request client auth */ + val wantClientAuth = props.getBoolean("want.client.auth", false) + + /** Require client auth */ + val needClientAuth = props.getBoolean("need.client.auth", false) +} diff --git a/core/src/main/scala/kafka/server/KafkaConfig.scala b/core/src/main/scala/kafka/server/KafkaConfig.scala index 48e3362..ee2e8fd 100644 --- a/core/src/main/scala/kafka/server/KafkaConfig.scala +++ b/core/src/main/scala/kafka/server/KafkaConfig.scala @@ -122,6 +122,10 @@ object Defaults { val DeleteTopicEnable = false val CompressionType = "producer" + + /** ********* SSL configuration ********************/ + val SslEnable = false + val SslConfigFilePath = new String("") } object KafkaConfig { @@ -225,6 +229,8 @@ object KafkaConfig { val DeleteTopicEnableProp = "delete.topic.enable" val CompressionTypeProp = "compression.type" + val SslEnableProp = "ssl.enable" + val SslConfigFilePathProp = "ssl.connection.config.file" /* Documentation */ @@ -341,7 +347,9 @@ object KafkaConfig { val DeleteTopicEnableDoc = "Enables delete topic. 
Delete topic through the admin tool will have no effect if this config is turned off" val CompressionTypeDoc = "Specify the final compression type for a given topic. This configuration accepts the standard compression codecs " + "('gzip', 'snappy', lz4). It additionally accepts 'uncompressed' which is equivalent to no compression; and " + - "'producer' which means retain the original compression codec set by the producer." + "'producer' which means retain the original compression codec set by the producer." + val SslEnableDoc = "Enables SSL socket server." + val SslConfigFilePathDoc = "SSL connection config properties file path." private val configDef = { @@ -457,6 +465,10 @@ object KafkaConfig { .define(OffsetCommitRequiredAcksProp, SHORT, Defaults.OffsetCommitRequiredAcks, HIGH, OffsetCommitRequiredAcksDoc) .define(DeleteTopicEnableProp, BOOLEAN, Defaults.DeleteTopicEnable, HIGH, DeleteTopicEnableDoc) .define(CompressionTypeProp, STRING, Defaults.CompressionType, HIGH, CompressionTypeDoc) + + /** ********* SSL configuration ********************/ + .define(SslEnableProp, BOOLEAN, Defaults.SslEnable, MEDIUM, SslEnableDoc) + .define(SslConfigFilePathProp, STRING, Defaults.SslConfigFilePath, MEDIUM, SslConfigFilePathDoc) } def configNames() = { @@ -574,7 +586,12 @@ object KafkaConfig { offsetCommitTimeoutMs = parsed.get(OffsetCommitTimeoutMsProp).asInstanceOf[Int], offsetCommitRequiredAcks = parsed.get(OffsetCommitRequiredAcksProp).asInstanceOf[Short], deleteTopicEnable = parsed.get(DeleteTopicEnableProp).asInstanceOf[Boolean], - compressionType = parsed.get(CompressionTypeProp).asInstanceOf[String] + compressionType = parsed.get(CompressionTypeProp).asInstanceOf[String], + + /** ********* SSL configuration ********************/ + sslEnable = parsed.get(SslEnableProp).asInstanceOf[Boolean], + sslConfigFilePath = parsed.get(SslConfigFilePathProp).asInstanceOf[String] + ) } @@ -715,7 +732,11 @@ class KafkaConfig(/** ********* Zookeeper Configuration ***********/ val 
offsetCommitRequiredAcks: Short = Defaults.OffsetCommitRequiredAcks, val deleteTopicEnable: Boolean = Defaults.DeleteTopicEnable, - val compressionType: String = Defaults.CompressionType + val compressionType: String = Defaults.CompressionType, + + /** ********* SSL configuration ********************/ + val sslEnable: Boolean = Defaults.SslEnable, + val sslConfigFilePath: String = Defaults.SslConfigFilePath ) { val zkConnectionTimeoutMs: Int = _zkConnectionTimeoutMs.getOrElse(zkSessionTimeoutMs) @@ -884,6 +905,9 @@ class KafkaConfig(/** ********* Zookeeper Configuration ***********/ props.put(DeleteTopicEnableProp, deleteTopicEnable.toString) props.put(CompressionTypeProp, compressionType.toString) + /** ********* SSL configuration ********************/ + props.put(SslEnableProp, sslEnable.toString) + props.put(SslConfigFilePathProp, sslConfigFilePath.toString) props } -} \ No newline at end of file +} diff --git a/core/src/main/scala/kafka/server/KafkaServer.scala b/core/src/main/scala/kafka/server/KafkaServer.scala index dddef93..7f687a7 100644 --- a/core/src/main/scala/kafka/server/KafkaServer.scala +++ b/core/src/main/scala/kafka/server/KafkaServer.scala @@ -126,6 +126,8 @@ class KafkaServer(val config: KafkaConfig, time: Time = SystemTime) extends Logg config.socketRequestMaxBytes, config.maxConnectionsPerIp, config.connectionsMaxIdleMs, + config.sslEnable, + config.sslConfigFilePath, config.maxConnectionsPerIpOverrides) socketServer.startup() diff --git a/core/src/main/scala/kafka/utils/SSLAuthUtils.scala b/core/src/main/scala/kafka/utils/SSLAuthUtils.scala new file mode 100644 index 0000000..541c786 --- /dev/null +++ b/core/src/main/scala/kafka/utils/SSLAuthUtils.scala @@ -0,0 +1,97 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils + +import java.util.concurrent.atomic.AtomicBoolean +import javax.net.ssl._ +import java.io.FileInputStream + +import kafka.network.ssl.SSLConnectionConfig + + +object SSLAuthUtils extends Logging { + private var context: SSLContext = null + private val initialized = new AtomicBoolean(false) + + def sslContext: SSLContext = { + if (!initialized.get) + throw new IllegalStateException("SSL authentication is not initialized.") + context + } + + def initializeSSLAuth(config: SSLConnectionConfig) { + if(!initialized.get) { + info("Initializing SSL authentication") + val keyStore = getKeyStore(config.keystoreType) + context = keyStore.initialize(config) + initialized.set(true) + info("SSL authentication initialization has been successfully completed") + } else + warn("Attempt to reinitialize auth context") + } + + def getKeyStore(name: String): StoreInitializer = { + name.toLowerCase match { + case JKSInitializer.name => JKSInitializer + case _ => throw new RuntimeException("%s is an unknown key store".format(name)) + } + } + +} + + + +trait StoreInitializer { + def initialize(config: SSLConnectionConfig): SSLContext +} + +object JKSInitializer extends StoreInitializer { + val name = "jks" + + def initialize(config: SSLConnectionConfig): SSLContext = { + val tms = config.truststorePwd match { + case pw: String => + val ts = java.security.KeyStore.getInstance("JKS") + 
val fis: FileInputStream = new FileInputStream(config.truststore) + ts.load(fis, pw.toCharArray) + fis.close() + val tmf = TrustManagerFactory.getInstance("SunX509") + tmf.init(ts) + tmf.getTrustManagers + } + val kms = config.keystorePwd match { + case pw: String => + val ks = java.security.KeyStore.getInstance("JKS") + val fis: FileInputStream = new FileInputStream(config.keystore) + ks.load(fis, pw.toCharArray) + fis.close() + + val kmf = KeyManagerFactory.getInstance("SunX509") + kmf.init(ks, if (config.keyPwd != null) config.keyPwd.toCharArray else pw.toCharArray) + kmf.getKeyManagers + case _ => null + } + initContext(tms, kms) + } + + private def initContext(tms: Array[TrustManager], kms: Array[KeyManager]): SSLContext = { + val authContext = SSLContext.getInstance("TLS") + authContext.init(kms, tms, null) + authContext + } +} diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala index 0af23ab..cdd5bc5 100644 --- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -18,18 +18,20 @@ package kafka.network; import java.net._ +import javax.net.ssl._ import java.io._ import org.junit._ import org.scalatest.junit.JUnitSuite import java.util.Random import junit.framework.Assert._ +import kafka.network.ssl.SSLConnectionConfig import kafka.producer.SyncProducerConfig import kafka.api.ProducerRequest import java.nio.ByteBuffer import kafka.common.TopicAndPartition import kafka.message.ByteBufferMessageSet import java.nio.channels.SelectionKey -import kafka.utils.TestUtils +import kafka.utils.{TestUtils, TestSSLUtils} import scala.collection.Map class SocketServerTest extends JUnitSuite { @@ -79,7 +81,7 @@ class SocketServerTest extends JUnitSuite { def cleanup() { server.shutdown() } - @Test + @Test def simpleRequest() { val socket = connect() val correlationId = -1 @@ -180,4 +182,35 @@ class 
SocketServerTest extends JUnitSuite {
     assertEquals(-1, conn.getInputStream.read())
     overrideServer.shutdown()
   }
+
+  @Test
+  def testSslSocketServer() {
+    val sslConfigFile = TestSSLUtils.createSslConfigFile()
+    val overrideServer: SocketServer = new SocketServer(0,
+                                                        host = null,
+                                                        port = kafka.utils.TestUtils.choosePort,
+                                                        numProcessorThreads = 1,
+                                                        maxQueuedRequests = 50,
+                                                        sendBufferSize = 300000,
+                                                        recvBufferSize = 300000,
+                                                        maxRequestSize = 50,
+                                                        maxConnectionsPerIp = 5,
+                                                        connectionsMaxIdleMs = 60*1000,
+                                                        sslEnable = true,
+                                                        sslConfigFilePath = sslConfigFile.getPath(),
+                                                        maxConnectionsPerIpOverrides = Map.empty[String,Int])
+    overrideServer.startup()
+    val sslConnectionConfig = new SSLConnectionConfig(sslConfigFile.getPath)
+    val sslContext = SSLContext.getInstance("TLSv1")
+    sslContext.init(null, Array(TestSSLUtils.trustAllCerts), new java.security.SecureRandom())
+    val socketFactory = sslContext.getSocketFactory
+    val socket = socketFactory.createSocket("localhost", sslConnectionConfig.port).asInstanceOf[SSLSocket]
+    socket.setNeedClientAuth(false)
+    val bytes = new Array[Byte](40)
+    // send a request first to make sure the connection has been picked up by the socket server
+    sendRequest(socket, 0, bytes)
+    processRequest(overrideServer.requestChannel)
+    // shut down the server this test created, not the shared fixture `server`
+    // (the fixture is shut down in cleanup(); shutting it down here leaks overrideServer)
+    overrideServer.shutdown()
+  }
+
 }
diff --git a/core/src/test/scala/unit/kafka/server/KafkaConfigConfigDefTest.scala b/core/src/test/scala/unit/kafka/server/KafkaConfigConfigDefTest.scala
index c124c8d..88da23f 100644
--- a/core/src/test/scala/unit/kafka/server/KafkaConfigConfigDefTest.scala
+++ b/core/src/test/scala/unit/kafka/server/KafkaConfigConfigDefTest.scala
@@ -146,6 +146,9 @@ class KafkaConfigConfigDefTest extends JUnit3Suite {
     Assert.assertEquals(expectedConfig.deleteTopicEnable, actualConfig.deleteTopicEnable)
     Assert.assertEquals(expectedConfig.compressionType, actualConfig.compressionType)
+
+    Assert.assertEquals(expectedConfig.sslEnable, actualConfig.sslEnable)
+
Assert.assertEquals(expectedConfig.sslConfigFilePath, actualConfig.sslConfigFilePath) } private def atLeastXIntProp(x: Int): String = (nextInt(Int.MaxValue - 1) + x).toString @@ -237,6 +240,7 @@ class KafkaConfigConfigDefTest extends JUnit3Suite { //BrokerCompressionCodec.isValid(compressionType) case KafkaConfig.CompressionTypeProp => expected.setProperty(name, randFrom(BrokerCompressionCodec.brokerCompressionOptions)) + case KafkaConfig.SslEnableProp => expected.setProperty(name, randFrom("true", "false")) case nonNegativeIntProperty => expected.setProperty(name, nextInt(Int.MaxValue).toString) } }) @@ -339,7 +343,8 @@ class KafkaConfigConfigDefTest extends JUnit3Suite { case KafkaConfig.OffsetCommitRequiredAcksProp => assertPropertyInvalid(getBaseProperties(), name, "not_a_number", "-2") case KafkaConfig.DeleteTopicEnableProp => assertPropertyInvalid(getBaseProperties(), name, "not_a_boolean", "0") - + case KafkaConfig.SslEnableProp => assertPropertyInvalid(getBaseProperties(), name, "not_a_boolean", "0") + case KafkaConfig.SslConfigFilePathProp => // ignore string case nonNegativeIntProperty => assertPropertyInvalid(getBaseProperties(), name, "not_a_number", "-1") } }) diff --git a/core/src/test/scala/unit/kafka/utils/TestSSLUtils.scala b/core/src/test/scala/unit/kafka/utils/TestSSLUtils.scala new file mode 100644 index 0000000..8995ce4 --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/TestSSLUtils.scala @@ -0,0 +1,178 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils + +import sun.security.x509._ +import java.security._ +import java.security.GeneralSecurityException +import java.security.Key +import java.security.KeyPair +import java.security.KeyPairGenerator +import java.security.KeyStore +import java.security.NoSuchAlgorithmException +import java.security.PrivateKey +import java.security.SecureRandom +import java.security.cert.Certificate +import java.security.cert.X509Certificate +import javax.net.ssl.X509TrustManager + +import java.io.File +import java.io.FileOutputStream +import java.io.FileWriter +import java.io.IOException +import java.io.Writer +import java.math.BigInteger +import java.net.URL +import java.util.Date +import java.util.Properties +import scala.collection.mutable.HashMap + + + +/** + * SSL utility functions to help with testing + */ +object TestSSLUtils extends Logging { + + @throws(classOf[GeneralSecurityException]) + @throws(classOf[IOException]) + def generateCertificate(dn: String, pair: KeyPair, days: Int, algorithm: String): X509Certificate = { + val privateKey = pair.getPrivate() + val info = new X509CertInfo() + val from = new Date() + val to = new Date(from.getTime() + days * 86400000l) + val interval = new CertificateValidity(from, to) + val sn = new BigInteger(64, new SecureRandom()) + val owner = new X500Name(dn) + + info.set(X509CertInfo.VALIDITY, interval) + info.set(X509CertInfo.SERIAL_NUMBER, new CertificateSerialNumber(sn)) + info.set(X509CertInfo.SUBJECT, new CertificateSubjectName(owner)) + info.set(X509CertInfo.ISSUER, new 
CertificateIssuerName(owner)) + info.set(X509CertInfo.KEY, new CertificateX509Key(pair.getPublic())) + info.set(X509CertInfo.VERSION, new CertificateVersion(CertificateVersion.V3)) + var algo = new AlgorithmId(AlgorithmId.md5WithRSAEncryption_oid) + info.set(X509CertInfo.ALGORITHM_ID, new CertificateAlgorithmId(algo)) + + //sign the cert to identify the algorithm that's used. + var cert = new X509CertImpl(info) + cert.sign(privateKey, algorithm) + + //update the algorithm, and resign. + algo = cert.get(X509CertImpl.SIG_ALG).asInstanceOf[AlgorithmId] + info.set(CertificateAlgorithmId.NAME + "." + CertificateAlgorithmId.ALGORITHM, algo) + cert = new X509CertImpl(info) + cert.sign(privateKey, algorithm) + cert + } + + def generateKeyPair(algorithm: String): KeyPair = { + val keyGen = KeyPairGenerator.getInstance(algorithm) + keyGen.initialize(1024) + keyGen.genKeyPair + } + + @throws(classOf[GeneralSecurityException]) + @throws(classOf[IOException]) + def createEmptyKeyStore: KeyStore = { + val ks = KeyStore.getInstance("JKS") + ks.load(null, null) + ks + } + + @throws(classOf[GeneralSecurityException]) + @throws(classOf[IOException]) + private def saveKeyStore(ks: KeyStore, filename: String, password: String) { + val out = new FileOutputStream(filename) + try { + ks.store(out, password.toCharArray) + } finally { + out.close() + } + } + + @throws(classOf[GeneralSecurityException]) + @throws(classOf[IOException]) + def createKeyStore(filename: String, password: String, keyPassword: String, alias: String, + privateKey: Key, cert: Certificate) { + val ks = createEmptyKeyStore + ks.setKeyEntry(alias, privateKey, keyPassword.toCharArray, Array(cert)) + saveKeyStore(ks, filename, password) + } + + @throws(classOf[GeneralSecurityException]) + @throws(classOf[IOException]) + def createTrustStore(filename: String, password: String, certs: HashMap[String, X509Certificate]) { + val ks = createEmptyKeyStore + certs foreach {case (key, cert) => ks.setCertificateEntry(key, cert)} + 
saveKeyStore(ks, filename, password) + } + + // a X509TrustManager to trust self-signed certs for unit tests. + def trustAllCerts: X509TrustManager = { + val trustManager = new X509TrustManager() { + override def getAcceptedIssuers: Array[X509Certificate] = { + null + } + override def checkClientTrusted(certs: Array[X509Certificate], authType: String) { + } + override def checkServerTrusted(certs: Array[X509Certificate], authType: String) { + } + } + trustManager + } + + /** + * Creates a temporary ssl config file + */ + def createSslConfigFile(): File = { + val certs = new HashMap[String, X509Certificate] + val keyPair = generateKeyPair("RSA") + val cert = generateCertificate("CN=localhost, O=localhost", keyPair, 30, "SHA1withRSA") + val password = "test" + + val keyStoreFile = File.createTempFile("keystore", ".jks") + createKeyStore(keyStoreFile.getPath(), password, password, "localhost", keyPair.getPrivate, cert) + certs.put("localhost", cert) + val trustStoreFile = File.createTempFile("truststore", ".jks") + createTrustStore(trustStoreFile.getPath(), password, certs) + + val f = File.createTempFile("ssl.server",".properties") + val sslProps = new Properties + sslProps.put("port", TestUtils.choosePort().toString) + sslProps.put("host", "localhost") + sslProps.put("keystore.type", "jks") + sslProps.put("want.client.auth", "false") + sslProps.put("need.client.auth", "false") + sslProps.put("keystore", keyStoreFile.getPath) + sslProps.put("keystorePwd", password) + sslProps.put("keyPwd", password) + sslProps.put("truststore", trustStoreFile.getPath) + sslProps.put("truststorePwd", password) + val outputStream = new FileOutputStream(f) + sslProps.store(outputStream, "") + outputStream.flush() + outputStream.getFD().sync() + outputStream.close() + f.deleteOnExit() + keyStoreFile.deleteOnExit() + trustStoreFiel.deleteOnExit() + f + } + +} -- 1.9.5 (Apple Git-50.3) From 2c8cb07e2c684867297a2b8ef5d0fa9dae38a711 Mon Sep 17 00:00:00 2001 From: Sriharsha Chintalapani Date: 
Wed, 11 Mar 2015 14:29:04 -0700 Subject: [PATCH 2/2] KAFKA-1684. Implement TLS/SSL authentication. --- core/src/test/scala/unit/kafka/utils/TestSSLUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/unit/kafka/utils/TestSSLUtils.scala b/core/src/test/scala/unit/kafka/utils/TestSSLUtils.scala index 8995ce4..ed53c04 100644 --- a/core/src/test/scala/unit/kafka/utils/TestSSLUtils.scala +++ b/core/src/test/scala/unit/kafka/utils/TestSSLUtils.scala @@ -171,7 +171,7 @@ object TestSSLUtils extends Logging { outputStream.close() f.deleteOnExit() keyStoreFile.deleteOnExit() - trustStoreFiel.deleteOnExit() + trustStoreFile.deleteOnExit() f } -- 1.9.5 (Apple Git-50.3)