diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java index f58b8508d3f813a51015abed772c704390887d7e..bfe37472f677d24a4da04d1c02ada1a14871e99d 100644 --- a/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java +++ b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java @@ -123,7 +123,7 @@ public class KafkaProducer implements Producer { List addresses = ClientUtils.parseAndValidateAddresses(config.getList(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG)); this.metadata.update(Cluster.bootstrap(addresses), time.milliseconds()); - NetworkClient client = new NetworkClient(new Selector(this.metrics, time), + NetworkClient client = new NetworkClient(new Selector(this.metrics, time, false), this.metadata, clientId, config.getInt(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION), diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/ProducerConfig.java b/clients/src/main/java/org/apache/kafka/clients/producer/ProducerConfig.java index f9de4af426449cceca12a8de9a9f54a6241d28d8..f462f2f84474cff0261429d5859b412daf4ae74a 100644 --- a/clients/src/main/java/org/apache/kafka/clients/producer/ProducerConfig.java +++ b/clients/src/main/java/org/apache/kafka/clients/producer/ProducerConfig.java @@ -171,6 +171,14 @@ public class ProducerConfig extends AbstractConfig { public static final String MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION = "max.in.flight.requests.per.connection"; private static final String MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION_DOC = "The maximum number of unacknowledged requests the client will send on a single connection before blocking."; + /** secure */ + public static final String SECURE = "secure"; + private static final String SECURE_DOC = "Determines whether to use SSL or not."; + + /** security.config.file */ + public static final String SECURITY_CONFIG_FILE = "security.config.file"; + private static final
String SECURITY_CONFIG_FILE_DOC = "The location of the security properties file used to configure SSL (keystore/truststore paths and passwords)."; + static { config = new ConfigDef().define(BOOTSTRAP_SERVERS_CONFIG, Type.LIST, Importance.HIGH, BOOSTRAP_SERVERS_DOC) .define(BUFFER_MEMORY_CONFIG, Type.LONG, 32 * 1024 * 1024L, atLeast(0L), Importance.HIGH, BUFFER_MEMORY_DOC) @@ -212,7 +220,17 @@ public class ProducerConfig extends AbstractConfig { 5, atLeast(1), Importance.LOW, - MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION_DOC); + MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION_DOC) + .define(SECURE, + Type.BOOLEAN, + false, + Importance.LOW, + SECURE_DOC) + .define(SECURITY_CONFIG_FILE, + Type.STRING, + "config/client.security.properties", + Importance.LOW, + SECURITY_CONFIG_FILE_DOC); } ProducerConfig(Map props) { diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/SSLSocketChannel.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/SSLSocketChannel.java new file mode 100644 index 0000000000000000000000000000000000000000..82078fbd93afe39b8ed21b02cc2662d315583cd7 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/SSLSocketChannel.java @@ -0,0 +1,656 @@ +package org.apache.kafka.clients.producer.internals; + +import org.apache.kafka.common.network.security.SecureAuth; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLException; +import java.io.IOException; +import java.net.Socket; +import java.net.SocketAddress; +import java.net.SocketOption; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.Set; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import
java.util.concurrent.atomic.AtomicInteger; + +import static javax.net.ssl.SSLEngineResult.HandshakeStatus; +import static javax.net.ssl.SSLEngineResult.HandshakeStatus.*; +import static javax.net.ssl.SSLEngineResult.Status.CLOSED; + +public class SSLSocketChannel extends SocketChannel { + private static final Logger log = LoggerFactory.getLogger(SSLSocketChannel.class); + + private final SSLSocketChannel outer = this; + private final SSLEngine sslEngine; + private final boolean simulateSlowNetwork = false; + private final int runningTasks = -2; + private final AtomicInteger counter = new AtomicInteger(0); + private final ThreadPoolExecutor executor = new ThreadPoolExecutor(2, 10, + 60L, TimeUnit.SECONDS, + new SynchronousQueue(), + new ThreadFactory() { + public Thread newThread(Runnable r) { + Thread thread = new Thread(r, String.format("SSLSession-Task-Thread-%d", + counter.incrementAndGet())); + thread.setDaemon(true); + return thread; + } + } + ); + + private HandshakeStatus handshakeStatus = NOT_HANDSHAKING; + + private SocketChannel underlying; + + private volatile int initialized = -1; + private boolean shutdown = false; + + private ByteBuffer peerAppData; + private ByteBuffer myNetData; + private ByteBuffer peerNetData; + private ByteBuffer emptyBuffer; + + private boolean blocking = false; + private Selector blockingSelector; + private SelectionKey blockingKey = null; + private volatile SelectionKey selectionKey = null; + + public SSLSocketChannel(SocketChannel underlying, SSLEngine sslEngine) throws IOException { + super(underlying.provider()); + this.underlying = underlying; + this.sslEngine = sslEngine; + this.peerAppData = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize()); + this.myNetData = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize()); + this.peerNetData = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize()); + this.emptyBuffer = ByteBuffer.allocate(0); + + myNetData.limit(0); + 
underlying.configureBlocking(false); + + this.blockingSelector = Selector.open(); + } + + public static SSLSocketChannel create(SocketChannel sch, String host, int port) throws IOException { + SSLEngine engine = SecureAuth.getSSLContext().createSSLEngine(host, port); + engine.setEnabledProtocols(new String[]{"SSLv3"}); + engine.setUseClientMode(true); + return new SSLSocketChannel(sch, engine); + } + + public void simulateBlocking(Boolean b) { + blocking = b; + } + + public Socket socket() { + return underlying.socket(); + } + + public boolean isConnected() { + return underlying.isConnected(); + } + + public boolean isConnectionPending() { + return underlying.isConnectionPending(); + } + + public boolean connect(SocketAddress remote) throws IOException { + boolean ret = underlying.connect(remote); + if (blocking) { + while (!finishConnect()) { + try { + Thread.sleep(10); + } catch (InterruptedException ignore) { + } + } + blockingKey = underlying.register(blockingSelector, SelectionKey.OP_READ); + try { + handshakeInBlockMode(SelectionKey.OP_WRITE); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return true; + } else { + return ret; + } + } + + public boolean finishConnect() throws IOException { + return underlying.finishConnect(); + } + + public boolean finished() { + return initialized == 0; + } + + public boolean isReadable() { + return finished() && (peerAppData.position() > 0 || peerNetData.position() > 0); + } + + public synchronized int read(ByteBuffer dst) throws IOException { + if (peerAppData.position() >= dst.remaining()) { + return readFromPeerData(dst); + } else if (underlying.socket().isInputShutdown()) { + throw new ClosedChannelException(); + } else if (initialized != 0) { + try { + handshake(SelectionKey.OP_READ, selectionKey); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return 0; + } else if (shutdown) { + shutdown(); + return -1; + } else if (sslEngine.isInboundDone()) { + return -1; + 
} else { + int count = (int) readRaw(); + if (count <= 0 && peerNetData.position() == 0) return count; + } + + try { + if (unwrap(false) < 0) return -1; + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + + return readFromPeerData(dst); + } + + public long read(ByteBuffer[] destination, int offset, int length) throws IOException { + int n = 0; + int i = offset; + + while (i < length) { + if (destination[i].hasRemaining()) { + int x = read(destination[i]); + if (x > 0) { + n += x; + if (!destination[i].hasRemaining()) { + break; + } + } else { + if ((x < 0) && (n == 0)) { + n = -1; + } + break; + } + } + i = i + 1; + } + + return n; + } + + public synchronized int write(ByteBuffer source) throws IOException { + if (myNetData.hasRemaining()) { + writeRaw(myNetData); + return 0; + } else if (underlying.socket().isOutputShutdown()) { + throw new ClosedChannelException(); + } else if (initialized != 0) { + try { + handshake(SelectionKey.OP_WRITE, selectionKey); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return 0; + } else if (shutdown) { + shutdown(); + return -1; + } + + int written; + try { + written = wrap(source); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + + while (myNetData.hasRemaining()) + writeRaw(myNetData); + + return written; + } + + public long write(ByteBuffer[] sources, int offset, int length) throws IOException { + int n = 0; + int i = offset; + while (i < length) { + if (sources[i].hasRemaining()) { + int x = write(sources[i]); + if (x > 0) { + n += x; + if (!sources[i].hasRemaining()) { + return 0; + } + } else { + return 0; + } + } + i = i + 1; + } + return n; + } + + @Override + public String toString() { + return "SSLSocketChannel[" + underlying.toString() + "]"; + } + + protected void implCloseSelectableChannel() throws IOException { + try { + _shutdown(); + } catch (Exception ignore) { + } + underlying.close(); + } + + protected void 
implConfigureBlocking(boolean block) throws IOException { + simulateBlocking(block); + if (!block) underlying.configureBlocking(block); + } + + public synchronized int handshake(int o, SelectionKey key) throws IOException, InterruptedException { + if (initialized == 0) return initialized; + + if (selectionKey == null) selectionKey = key; + + if (initialized != -1) { + if (writeIfReadyAndNeeded(o, false)) return o; + } + int init = localHandshake(o); + if (init != runningTasks) { + initialized = init; + } + return init; + } + + private boolean writeIfReadyAndNeeded(int o, boolean mustWrite) throws IOException { + if ((o & SelectionKey.OP_WRITE) != 0) { + writeRaw(myNetData); + return myNetData.remaining() > 0; + } else { + return mustWrite; + } + } + + private boolean readIfReadyAndNeeded(int o, boolean mustRead) throws IOException, InterruptedException { + if ((o & SelectionKey.OP_READ) != 0) { + if (readRaw() < 0) { + shutdown = true; + underlying.close(); + return true; + } + int oldPos = peerNetData.position(); + unwrap(true); + return oldPos == peerNetData.position(); + } else { + return mustRead; + } + } + + private int localHandshake(int o) throws IOException, InterruptedException { + while (!Thread.currentThread().isInterrupted()) { + switch (handshakeStatus) { + case NOT_HANDSHAKING: + sslEngine.beginHandshake(); + handshakeStatus = sslEngine.getHandshakeStatus(); + break; + case NEED_UNWRAP: + if (readIfReadyAndNeeded(o, true) && handshakeStatus != FINISHED) { + return SelectionKey.OP_READ; + } + break; + case NEED_WRAP: + if (myNetData.remaining() == 0) { + wrap(emptyBuffer); + } + if (writeIfReadyAndNeeded(o, true)) { + return SelectionKey.OP_WRITE; + } + break; + case NEED_TASK: + handshakeStatus = runTasks(SelectionKey.OP_READ); + break; + case FINISHED: + underlying.socket().getLocalSocketAddress(); + return 0; + default: + if (handshakeStatus == null) return runningTasks; + } + } + + return o; + } + + public void shutdown() throws IOException { + 
_shutdown(); + underlying.close(); + } + + synchronized private void _shutdown() throws IOException { + shutdown = true; + + try { + if (!sslEngine.isOutboundDone()) sslEngine.closeOutbound(); + + myNetData.compact(); + while (!sslEngine.isOutboundDone()) { + SSLEngineResult res = sslEngine.wrap(emptyBuffer, myNetData); + if (res.getStatus() != CLOSED) { + throw new SSLException(String.format("Unexpected shutdown status '%s'", res.getStatus())); + } + + myNetData.flip(); + try { + while (myNetData.hasRemaining()) + writeRaw(myNetData); + } catch (IOException ignore) { + } + } + } finally { + if (blockingKey != null) { + try { + blockingKey.cancel(); + } finally { + blockingKey = null; + blockingSelector.close(); + } + } + } + } + + private int handshakeInBlockMode(int ops) throws InterruptedException, IOException { + int o = ops; + while (o != 0) { + int tops = handshake(o, null); + if (tops == o) { + try { + Thread.sleep(10); + } catch (Throwable e) { + //case _: => InterruptedException + } + } else { + o = tops; + } + } + return o; + } + + private void blockIfNeeded() { + if (blockingKey != null) { + try { + blockingSelector.select(5000); + } catch (Throwable e) { + //error("Unexpected error in blocking select", t) + } + } + } + + private synchronized long readRaw() throws IOException { + blockIfNeeded(); + try { + int n = underlying.read(peerNetData); + if (n < 0) { + sslEngine.closeInbound(); + } + + return n; + } catch (IOException e) { + sslEngine.closeInbound(); + throw e; + } + } + + private int unwrap(boolean isHandshaking) throws IOException, InterruptedException { + int pos = peerAppData.position(); + peerNetData.flip(); + try { + while (peerNetData.hasRemaining()) { + SSLEngineResult result = sslEngine.unwrap(peerNetData, peerAppData); + handshakeStatus = result.getHandshakeStatus(); + switch(result.getStatus()) { + case OK: + if (handshakeStatus == NEED_TASK) { + handshakeStatus = runTasks(SelectionKey.OP_READ); + if (handshakeStatus == null) return 0; 
+ } + if (isHandshaking && handshakeStatus == HandshakeStatus.FINISHED) { + return peerAppData.position() - pos; + } + case BUFFER_OVERFLOW: + peerAppData = expand(peerAppData, sslEngine.getSession().getApplicationBufferSize()); + break; + case BUFFER_UNDERFLOW: + return 0; + case CLOSED: + if (peerAppData.position() == 0) { + shutdown(); + return -1; + } else { + shutdown = true; + return 0; + } + default: throw new SSLException("Unexpected state!"); + } + } + } finally { + peerNetData.compact(); + } + + return peerAppData.position() - pos; + } + + private int wrap(ByteBuffer src) throws IOException, InterruptedException { + int written = src.remaining(); + myNetData.compact(); + try { + do { + SSLEngineResult result = sslEngine.wrap(src, myNetData); + handshakeStatus = result.getHandshakeStatus(); + switch(result.getStatus()) { + case OK: + if (handshakeStatus == NEED_TASK) { + handshakeStatus = runTasks(SelectionKey.OP_READ); + if (handshakeStatus == null) return 0; + } + break; + case BUFFER_OVERFLOW: + int size = (src.remaining() * 2 > sslEngine.getSession().getApplicationBufferSize()) ? + src.remaining() * 2 : + sslEngine.getSession().getApplicationBufferSize(); + myNetData = expand(myNetData, size); + break; + case CLOSED: + shutdown(); + throw new IOException("Write error received Status.CLOSED"); + default: throw new SSLException("Unexpected state!"); + } + } while (src.hasRemaining()); + } finally { + myNetData.flip(); + } + return written; + } + + private long writeRaw(ByteBuffer out) throws IOException { + try { + if (out.hasRemaining()) { + return underlying.write( (simulateSlowNetwork) ? 
writeTwo(out) : out ); + } else { + return 0; + } + } catch (IOException e) { + sslEngine.closeOutbound(); + shutdown = true; + throw e; + } + } + + private ByteBuffer writeTwo(ByteBuffer i) { + ByteBuffer o = ByteBuffer.allocate(2); + int rem = i.limit() - i.position(); + if (rem > o.capacity()) rem = o.capacity(); + int c = 0; + while (c < rem) { + o.put(i.get()); + c += 1; + } + o.flip(); + return o; + } + + private HandshakeStatus runTasks(int ops) throws IOException, InterruptedException { + boolean reInitialize; + switch (initialized) { + case 0: + initialized = ops; + reInitialize = true; + break; + default: + reInitialize = false; + } + + Runnable runnable = sslEngine.getDelegatedTask(); + if (!blocking && selectionKey != null) { + if (runnable != null) { + executor.execute(new SSLTasker(runnable)); + } + return null; + } else { + while (runnable != null) { + runnable.run(); + runnable = sslEngine.getDelegatedTask(); + } + if (reInitialize) { + handshakeInBlockMode(ops); + } + return sslEngine.getHandshakeStatus(); + } + } + + private ByteBuffer expand(ByteBuffer src, int ensureSize) { + if (src.remaining() < ensureSize) { + ByteBuffer newBuffer = ByteBuffer.allocate(src.capacity() + ensureSize); + if (src.position() > 0) { + src.flip(); + newBuffer.put(src); + } + return newBuffer; + } else { + return src; + } + } + + private int readFromPeerData(ByteBuffer dest) { + peerAppData.flip(); + try { + int remaining = peerAppData.remaining(); + if (remaining > 0) { + if (remaining > dest.remaining()) { + remaining = dest.remaining(); + } + int i = 0; + while (i < remaining) { + dest.put(peerAppData.get()); + i = i + 1; + } + } + return remaining; + } finally { + peerAppData.compact(); + } + } + + public SocketChannel bind(SocketAddress local) throws IOException { + return underlying.bind(local); + } + + public SocketChannel shutdownInput() throws IOException { + // Delegate to the wrapped channel; the previous body called itself, + // causing unconditional infinite recursion (StackOverflowError). + return underlying.shutdownInput(); + } + + public SocketChannel setOption(SocketOption name, T value) throws IOException { + return
underlying.setOption(name, value); + } + + public T getOption(SocketOption name) throws IOException { + return underlying.getOption(name); + } + + public SocketAddress getRemoteAddress() throws IOException { + return underlying.getRemoteAddress(); + } + + public SocketChannel shutdownOutput() throws IOException { + return underlying.shutdownOutput(); + } + + public SocketAddress getLocalAddress() throws IOException { + return underlying.getLocalAddress(); + } + + public Set> supportedOptions() { + return underlying.supportedOptions(); + } + + private class SSLTasker implements Runnable { + private Runnable runnable; + + private SSLTasker(Runnable runnable) { + this.runnable = runnable; + selectionKey.interestOps(0); + } + + public void run() { + try { + runnable.run(); + synchronized(outer) { + handshakeStatus = sslEngine.getHandshakeStatus(); + switch(handshakeStatus) { + case NEED_WRAP: + selectionKey.interestOps(SelectionKey.OP_WRITE); + break; + case NEED_UNWRAP: + if (peerNetData.position() > 0) { + int init = outer.handshake(SelectionKey.OP_READ, selectionKey); + if (init == 0) { + selectionKey.interestOps(SelectionKey.OP_READ); + } else if (init != runningTasks) { + selectionKey.interestOps(init); + } + } else { + selectionKey.interestOps(SelectionKey.OP_READ); + } + case NEED_TASK: + Runnable runnable = sslEngine.getDelegatedTask(); + if (runnable != null) { + executor.execute(new SSLTasker(runnable)); + handshakeStatus = null; + } + return; + default: throw new SSLException("unexpected handshakeStatus: " + handshakeStatus); + } + selectionKey.selector().wakeup(); + } + } catch (InterruptedException e) { + log.error(e.getMessage(), e); + } catch (SSLException e) { + log.error(e.getMessage(), e); + } catch (IOException e) { + log.error(e.getMessage(), e); + } + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/UnknownKeyStoreException.java b/clients/src/main/java/org/apache/kafka/common/errors/UnknownKeyStoreException.java new file 
mode 100644 index 0000000000000000000000000000000000000000..b9d07c8d2ab81ae90a622f720d4692d433578ec2 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/UnknownKeyStoreException.java @@ -0,0 +1,11 @@ +package org.apache.kafka.common.errors; + +/** + * Created by ilyutov on 6/5/14. + */ +public class UnknownKeyStoreException extends Exception { + + public UnknownKeyStoreException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/Selector.java b/clients/src/main/java/org/apache/kafka/common/network/Selector.java index 4dd2cdf773f7eb01a93d7f994383088960303dfc..73a6b332aa375ba4ae55365414a06b93d6f8d0fa 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/Selector.java +++ b/clients/src/main/java/org/apache/kafka/common/network/Selector.java @@ -20,6 +20,7 @@ import java.nio.channels.CancelledKeyException; import java.nio.channels.SelectionKey; import java.nio.channels.SocketChannel; import java.nio.channels.UnresolvedAddressException; +import java.nio.channels.spi.SelectorProvider; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; @@ -28,6 +29,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; +import org.apache.kafka.clients.producer.internals.SSLSocketChannel; import org.apache.kafka.common.KafkaException; import org.apache.kafka.common.metrics.Measurable; import org.apache.kafka.common.metrics.MetricConfig; @@ -37,6 +39,7 @@ import org.apache.kafka.common.metrics.stats.Avg; import org.apache.kafka.common.metrics.stats.Count; import org.apache.kafka.common.metrics.stats.Max; import org.apache.kafka.common.metrics.stats.Rate; +import org.apache.kafka.common.network.security.SecureAuth; import org.apache.kafka.common.utils.Time; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -81,11 +84,12 @@ public class Selector implements Selectable { private final List connected; private final Time time; private 
final SelectorMetrics sensors; + private final boolean secure; /** * Create a new selector */ - public Selector(Metrics metrics, Time time) { + public Selector(Metrics metrics, Time time, boolean secure) { try { this.selector = java.nio.channels.Selector.open(); } catch (IOException e) { @@ -98,6 +102,7 @@ public class Selector implements Selectable { this.connected = new ArrayList(); this.disconnected = new ArrayList(); this.sensors = new SelectorMetrics(metrics); + this.secure = secure; } /** @@ -118,7 +123,7 @@ public class Selector implements Selectable { if (this.keys.containsKey(id)) throw new IllegalStateException("There is already a connection for id " + id); - SocketChannel channel = SocketChannel.open(); + SocketChannel channel = secure ? SSLSocketChannel.create(SocketChannel.open(), address.getHostString(), address.getPort()) : SocketChannel.open(); channel.configureBlocking(false); Socket socket = channel.socket(); socket.setKeepAlive(true); diff --git a/clients/src/main/java/org/apache/kafka/common/network/security/AuthConfig.java b/clients/src/main/java/org/apache/kafka/common/network/security/AuthConfig.java new file mode 100644 index 0000000000000000000000000000000000000000..99a35dfae226a202a1e669cf59c4f4341837fbda --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/security/AuthConfig.java @@ -0,0 +1,68 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.network.security; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Properties; + +public class AuthConfig { + public static String DEFAULT_SECURITY_CONFIG = "config/client.security.properties"; + private final Properties props; + + public AuthConfig(String securityConfigFile) throws IOException { + if (securityConfigFile == null) { + securityConfigFile = AuthConfig.DEFAULT_SECURITY_CONFIG; + } + props = new Properties(); + props.load(Files.newInputStream(Paths.get(securityConfigFile))); + } + + public String getKeystoreType() { + return props.getProperty("keystore.type"); + } + + public boolean wantClientAuth() { + return Boolean.valueOf(props.getProperty("want.client.auth")); + } + + public boolean needClientAuth() { + return Boolean.valueOf(props.getProperty("need.client.auth")); + } + + public String getKeystore() { + return props.getProperty("keystore"); + } + + public String getKeystorePassword() { + return props.getProperty("keystorePwd"); + } + + public String getKeyPassword() { + return props.getProperty("keyPwd"); + } + + public String getTruststore() { + return props.getProperty("truststore"); + } + + public String getTruststorePassword() { + return props.getProperty("truststorePwd"); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/security/KeyStores.java b/clients/src/main/java/org/apache/kafka/common/network/security/KeyStores.java new file mode 100644 index
0000000000000000000000000000000000000000..ae8d6cf9b9cbc2e441dabb7943abc606066e7469 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/security/KeyStores.java @@ -0,0 +1,20 @@ +package org.apache.kafka.common.network.security; + +import org.apache.kafka.common.errors.UnknownKeyStoreException; +import org.apache.kafka.common.network.security.store.JKSInitializer; + +/** + * Created by ilyutov on 6/5/14. + */ +public class KeyStores { + private KeyStores() { + } + + public static StoreInitializer getKeyStore(String name) throws UnknownKeyStoreException { + if (JKSInitializer.NAME.equals(name)) { + return JKSInitializer.getInstance(); + } else { + throw new UnknownKeyStoreException(String.format("%s is an unknown key store", name)); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/security/SecureAuth.java b/clients/src/main/java/org/apache/kafka/common/network/security/SecureAuth.java new file mode 100644 index 0000000000000000000000000000000000000000..f48b1aaf7ce4ed602237a6281fdf48a0d51e7b15 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/security/SecureAuth.java @@ -0,0 +1,40 @@ +package org.apache.kafka.common.network.security; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.SSLContext; +import java.util.concurrent.atomic.AtomicBoolean; + +public class SecureAuth { + private static final Logger log = LoggerFactory.getLogger(SecureAuth.class); + private static final AtomicBoolean initialized = new AtomicBoolean(false); + private static SSLContext authContext; + + private SecureAuth() { + } + + public static SSLContext getSSLContext() { + if (!initialized.get()) { + throw new IllegalStateException("Secure authentication is not initialized."); + } + return authContext; + } + + public static void initialize(AuthConfig config) throws Exception { + if (initialized.get()) { + log.warn("Attempt to reinitialize auth context"); + return; + } + + 
log.info("Initializing secure authentication"); + + StoreInitializer initializer = KeyStores.getKeyStore(config.getKeystoreType()); + authContext = initializer.initialize(config); + + initialized.set(true); + + log.info("Secure authentication initialization has been successfully completed"); + } +} + diff --git a/clients/src/main/java/org/apache/kafka/common/network/security/StoreInitializer.java b/clients/src/main/java/org/apache/kafka/common/network/security/StoreInitializer.java new file mode 100644 index 0000000000000000000000000000000000000000..7a06b10df37a9c4f53779bf71d1b2f772d139360 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/security/StoreInitializer.java @@ -0,0 +1,16 @@ +package org.apache.kafka.common.network.security; + +import javax.net.ssl.SSLContext; +import java.io.IOException; +import java.security.KeyManagementException; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.UnrecoverableKeyException; +import java.security.cert.CertificateException; + +/** + * Created by ilyutov on 6/5/14. + */ +public interface StoreInitializer { + SSLContext initialize(AuthConfig config) throws Exception; +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/network/security/store/JKSInitializer.java b/clients/src/main/java/org/apache/kafka/common/network/security/store/JKSInitializer.java new file mode 100644 index 0000000000000000000000000000000000000000..9e2ef0a05d9d283c3480504cf36098242ccb362c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/security/store/JKSInitializer.java @@ -0,0 +1,87 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.network.security.store; + +import org.apache.kafka.common.network.security.AuthConfig; +import org.apache.kafka.common.network.security.StoreInitializer; + +import javax.net.ssl.*; +import java.io.FileInputStream; +import java.security.KeyManagementException; +import java.security.KeyStore; +import java.security.NoSuchAlgorithmException; + +public class JKSInitializer implements StoreInitializer { + public static final String NAME = "jks"; + private static JKSInitializer instance = null; + + private JKSInitializer() { + } + + public static JKSInitializer getInstance() { + if (instance == null) { + synchronized (JKSInitializer.class) { + if (instance == null) { + instance = new JKSInitializer(); + } + } + } + + return instance; + } + + public SSLContext initialize(AuthConfig config) throws Exception { + TrustManager[] trustManagers = getTrustManagers(config); + KeyManager[] keyManagers = getKeyManagers(config); + + return initContext(trustManagers, keyManagers); + } + + private TrustManager[] getTrustManagers(AuthConfig config) throws Exception { + KeyStore trustStore = KeyStore.getInstance("JKS"); + FileInputStream in = new FileInputStream(config.getTruststore()); + trustStore.load(in, config.getTruststorePassword().toCharArray()); + in.close(); + + TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509"); + tmf.init(trustStore); + + 
return tmf.getTrustManagers(); + } + + private KeyManager[] getKeyManagers(AuthConfig config) throws Exception { + KeyStore keyStore = KeyStore.getInstance("JKS"); + FileInputStream in = new FileInputStream(config.getKeystore()); + keyStore.load(in, config.getKeystorePassword().toCharArray()); + in.close(); + + KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); + kmf.init(keyStore, (config.getKeyPassword() != null) ? + config.getKeyPassword().toCharArray() : + config.getKeystorePassword().toCharArray()); + + return kmf.getKeyManagers(); + } + + private SSLContext initContext(TrustManager[] tms, KeyManager[] kms) throws KeyManagementException, NoSuchAlgorithmException { + SSLContext authContext = SSLContext.getInstance("TLS"); + authContext.init(kms, tms, null); + return authContext; + } +} + diff --git a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java index 5c5e3d40819e41cab7b52a0eeaee5f2e7317b7b3..c511fa051f61af4e3b9befb874e0c5cf1655d917 100644 --- a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java +++ b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java @@ -50,7 +50,7 @@ public class SelectorTest { public void setup() throws Exception { this.server = new EchoServer(); this.server.start(); - this.selector = new Selector(new Metrics(), new MockTime()); + this.selector = new Selector(new Metrics(), new MockTime(), false); } @After diff --git a/config/client.keystore b/config/client.keystore new file mode 100644 index 0000000000000000000000000000000000000000..8ea0757e359f2c2af5bac9222f31eaf771d4f226 Binary files /dev/null and b/config/client.keystore differ diff --git a/config/client.public-key b/config/client.public-key new file mode 100644 index 0000000000000000000000000000000000000000..8b10172b3a7efc090fad06aa56278a94bb07be2b --- /dev/null +++ b/config/client.public-key @@ -0,0 +1,14 @@ +-----BEGIN 
CERTIFICATE----- +MIICeTCCAeKgAwIBAgIEUbWLWTANBgkqhkiG9w0BAQUFADCBgDELMAkGA1UEBhMCVVMxEzARBgNV +BAgTCkNhbGlmb3JuaWExFjAUBgNVBAcTDVNhbiBGcmFuY2lzY28xFzAVBgNVBAoTDlNhbGVzZm9y +Y2UuY29tMRQwEgYDVQQLEwtBcHBsaWNhdGlvbjEVMBMGA1UEAxMMS2Fma2EgQ2xpZW50MB4XDTEz +MDYxMDA4MTYyNVoXDTEzMDkwODA4MTYyNVowgYAxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxp +Zm9ybmlhMRYwFAYDVQQHEw1TYW4gRnJhbmNpc2NvMRcwFQYDVQQKEw5TYWxlc2ZvcmNlLmNvbTEU +MBIGA1UECxMLQXBwbGljYXRpb24xFTATBgNVBAMTDEthZmthIENsaWVudDCBnzANBgkqhkiG9w0B +AQEFAAOBjQAwgYkCgYEAqtviDb8lrz+gfz91B1CXtaF3E0CRUh3YeHx1AwVqX8sXvTviAc6qM2Sv +Cpwi0x+dbq09uTrgo1NjAEc5ycnAgUXUi2/Jo6AmYz0MbRB7lX2hpc4drPDFknLlZCI1hgat42N4 +dq0L0fJ30VHnpvBErtnHij8SGBX55bKTu0PVPy8CAwEAATANBgkqhkiG9w0BAQUFAAOBgQBDaeOz +RbNdFy8d2eafnUEv2xt7/zTitTADrbs9RD3ZqTD8oziHNwQXFPuEtknx8myIZfephjjB0jHuGdrV +/xSAQwnufZQXJbfYpouKrF1mTT7Myn+kl6nGETu4tZyVH0MQqMFPhV5x6h5o9f/6Ei+DTK7VgdEv +f0HYifb50G6JMw== +-----END CERTIFICATE----- diff --git a/config/client.security.properties b/config/client.security.properties new file mode 100644 index 0000000000000000000000000000000000000000..de0b90a4ab97f26618c036413e43f1bdf8626889 --- /dev/null +++ b/config/client.security.properties @@ -0,0 +1,9 @@ +# Keystore file +keystore.type=jks +keystore=config/client.keystore +keystorePwd=test1234 +keyPwd=test1234 + +# Truststore file +truststore=config/client.keystore +truststorePwd=test1234 \ No newline at end of file diff --git a/config/consumer.properties b/config/consumer.properties index 83847de30d10b6e78bb8de28e0bb925d7c0e6ca2..81ca194c3e2110aa91bd485b415bebcd92463e85 100644 --- a/config/consumer.properties +++ b/config/consumer.properties @@ -27,3 +27,6 @@ group.id=test-consumer-group #consumer timeout #consumer.timeout.ms=5000 + +# Security config +security.config.file=config/client.security.properties diff --git a/config/producer.properties b/config/producer.properties index 39d65d7c6c21f4fccd7af89be6ca12a088d5dd98..bda5a8b1f6d3c0da6a8e53d1b67ef127b1f28a9a 100644 --- a/config/producer.properties +++ 
b/config/producer.properties @@ -36,6 +36,9 @@ serializer.class=kafka.serializer.DefaultEncoder # allow topic level compression #compressed.topics= +# Security config +security.config.file=config/client.security.properties + ############################# Async Producer ############################# # maximum time, in milliseconds, for buffering data on the producer queue #queue.buffering.max.ms= diff --git a/config/server.keystore b/config/server.keystore new file mode 100644 index 0000000000000000000000000000000000000000..fe0dd6a4803060a9c9be80b5d4649ca4e579c13b Binary files /dev/null and b/config/server.keystore differ diff --git a/config/server.properties b/config/server.properties index 5c0905a572b1f0d8b07bfca967a09cb856a6b09f..eb01e5b9881034e6b6b37e7f8745a27da4799841 100644 --- a/config/server.properties +++ b/config/server.properties @@ -36,6 +36,12 @@ port=9092 # it will publish the same port that the broker binds to. #advertised.port= +# SSL or plaintext +secure=true + +# Security config +security.config.file=config/server.security.properties + # The number of threads handling network requests num.network.threads=3 @@ -49,7 +55,7 @@ socket.send.buffer.bytes=102400 socket.receive.buffer.bytes=65536 # The maximum size of a request that the socket server will accept (protection against OOM) -socket.request.max.bytes=104857600 +socket.request.max.bytes=1048576000 ############################# Log Basics ############################# diff --git a/config/server.public-key b/config/server.public-key new file mode 100644 index 0000000000000000000000000000000000000000..abb5d6a3b4d319205309c1ec60f6a3980801a332 --- /dev/null +++ b/config/server.public-key @@ -0,0 +1,14 @@ +-----BEGIN CERTIFICATE----- +MIICeTCCAeKgAwIBAgIEUbWLGTANBgkqhkiG9w0BAQUFADCBgDELMAkGA1UEBhMCVVMxEzARBgNV +BAgTCkNhbGlmb3JuaWExFjAUBgNVBAcTDVNhbiBGcmFuY2lzY28xFzAVBgNVBAoTDlNhbGVzZm9y +Y2UuY29tMRQwEgYDVQQLEwtBcHBsaWNhdGlvbjEVMBMGA1UEAxMMS2Fma2EgQnJva2VyMB4XDTEz 
+MDYxMDA4MTUyMVoXDTEzMDkwODA4MTUyMVowgYAxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxp +Zm9ybmlhMRYwFAYDVQQHEw1TYW4gRnJhbmNpc2NvMRcwFQYDVQQKEw5TYWxlc2ZvcmNlLmNvbTEU +MBIGA1UECxMLQXBwbGljYXRpb24xFTATBgNVBAMTDEthZmthIEJyb2tlcjCBnzANBgkqhkiG9w0B +AQEFAAOBjQAwgYkCgYEAnOJKErOQOfw9SbLpUbBtnU19Js1a0lBQ1oGz9UEPSgqaCjw7hyNipbP9 +FmwHyppD90ALZyggSjrOc+WG7RGA0YcHLCal/AmQc+aNDtfzoLAjJCFduBY6wgpL5dLUe/MYjndS +6uZyehzZyiWNgltRB9bRRHF/66RZ3m1jb7vnzd0CAwEAATANBgkqhkiG9w0BAQUFAAOBgQCOB6lT +7RKW7ktIyv0CCI0lreU4fXfLZCzCkMwu2xBbJfs3x67CgkoP+CJVRZN0xuBmJ3Yxr+s0XDHlgk9A +nL9mD3yc/wf4xS5meG5M+Ge8peLqORBMIJES18uvFaFkH6MyRIePfhM4lXKnjrTiM4u3tXZMYBbE +SwTotApOGMXVMA== +-----END CERTIFICATE----- diff --git a/config/server.security.properties b/config/server.security.properties new file mode 100644 index 0000000000000000000000000000000000000000..0bcbef1fa87769d42bd5cfea48d524792b11ef42 --- /dev/null +++ b/config/server.security.properties @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# see kafka.security.Config for more details + +#type of keystore +keystore.type=jks + +# Request client auth +want.client.auth=true + +# Require client auth +need.client.auth=true + +# Keystore file +keystore=config/server.keystore +keystorePwd=test1234 +keyPwd=test1234 + +# Truststore file +truststore=config/server.keystore +truststorePwd=test1234 diff --git a/core/src/main/scala/kafka/api/FetchRequest.scala b/core/src/main/scala/kafka/api/FetchRequest.scala index 51cdccf7f90eb530cc62b094ed822b8469d50b12..763ca43bc3c46d96a1299a65ffca4eb0d49c8ed0 100644 --- a/core/src/main/scala/kafka/api/FetchRequest.scala +++ b/core/src/main/scala/kafka/api/FetchRequest.scala @@ -177,6 +177,7 @@ class FetchRequestBuilder() { private var maxWait = FetchRequest.DefaultMaxWait private var minBytes = FetchRequest.DefaultMinBytes private val requestMap = new collection.mutable.HashMap[TopicAndPartition, PartitionFetchInfo] + private var secure = false def addFetch(topic: String, partition: Int, offset: Long, fetchSize: Int) = { requestMap.put(TopicAndPartition(topic, partition), PartitionFetchInfo(offset, fetchSize)) @@ -206,6 +207,11 @@ class FetchRequestBuilder() { this } + def secure(secure: Boolean): FetchRequestBuilder = { + this.secure = secure + this + } + def build() = { val fetchRequest = FetchRequest(versionId, correlationId.getAndIncrement, clientId, replicaId, maxWait, minBytes, requestMap.toMap) requestMap.clear() diff --git a/core/src/main/scala/kafka/client/ClientUtils.scala b/core/src/main/scala/kafka/client/ClientUtils.scala index ce7ede3f6d60e756e252257bd8c6fedc21f21e1c..80a0f103d9ce39a514c40bd088b479c23b347dc4 100644 --- a/core/src/main/scala/kafka/client/ClientUtils.scala +++ b/core/src/main/scala/kafka/client/ClientUtils.scala @@ -28,6 +28,7 @@ import util.Random import kafka.utils.ZkUtils._ import org.I0Itec.zkclient.ZkClient import java.io.IOException + import java.util.concurrent.TimeUnit /** * Helper functions common to clients (producer, consumer, or admin) 
@@ -82,12 +83,17 @@ object ClientUtils extends Logging{ * @param clientId The client's identifier * @return topic metadata response */ - def fetchTopicMetadata(topics: Set[String], brokers: Seq[Broker], clientId: String, timeoutMs: Int, + def fetchTopicMetadata(topics: Set[String], brokers: Seq[Broker], clientId: String, securityConfigFile: String, timeoutMs: Int, correlationId: Int = 0): TopicMetadataResponse = { val props = new Properties() props.put("metadata.broker.list", brokers.map(_.getConnectionString()).mkString(",")) props.put("client.id", clientId) props.put("request.timeout.ms", timeoutMs.toString) + if (securityConfigFile != null){ + props.put("secure", "true") + props.put("security.config.file", securityConfigFile) + } + val producerConfig = new ProducerConfig(props) fetchTopicMetadata(topics, brokers, producerConfig, correlationId) } @@ -104,7 +110,8 @@ object ClientUtils extends Logging{ val brokerInfos = brokerStr.split(":") val hostName = brokerInfos(0) val port = brokerInfos(1).toInt - new Broker(brokerId, hostName, port) + val secure = if (brokerInfos.length > 2) brokerInfos(2).toBoolean else false + new Broker(brokerId, hostName, port, secure) }) } @@ -119,7 +126,7 @@ object ClientUtils extends Logging{ Random.shuffle(allBrokers).find { broker => trace("Connecting to broker %s:%d.".format(broker.host, broker.port)) try { - channel = new BlockingChannel(broker.host, broker.port, BlockingChannel.UseDefaultBufferSize, BlockingChannel.UseDefaultBufferSize, socketTimeoutMs) + channel = new BlockingChannel(broker.host, broker.port, broker.secure, BlockingChannel.UseDefaultBufferSize, BlockingChannel.UseDefaultBufferSize, socketTimeoutMs) channel.connect() debug("Created channel to broker %s:%d.".format(channel.host, channel.port)) true @@ -181,7 +188,7 @@ object ClientUtils extends Logging{ var offsetManagerChannel: BlockingChannel = null try { debug("Connecting to offset manager %s.".format(connectString)) - offsetManagerChannel = new 
BlockingChannel(coordinator.host, coordinator.port, + offsetManagerChannel = new BlockingChannel(coordinator.host, coordinator.port, coordinator.secure, BlockingChannel.UseDefaultBufferSize, BlockingChannel.UseDefaultBufferSize, socketTimeoutMs) diff --git a/core/src/main/scala/kafka/cluster/Broker.scala b/core/src/main/scala/kafka/cluster/Broker.scala index 9407ed21fbbd57edeecd888edc32bea6a05d95b3..1008d328633ae75aed511a3fa841f5c087788019 100644 --- a/core/src/main/scala/kafka/cluster/Broker.scala +++ b/core/src/main/scala/kafka/cluster/Broker.scala @@ -37,7 +37,8 @@ private[kafka] object Broker { val brokerInfo = m.asInstanceOf[Map[String, Any]] val host = brokerInfo.get("host").get.asInstanceOf[String] val port = brokerInfo.get("port").get.asInstanceOf[Int] - new Broker(id, host, port) + val secure = if (brokerInfo.get("secure").get.asInstanceOf[Int] == 0) false else true + new Broker(id, host, port, secure) case None => throw new BrokerNotAvailableException("Broker id %d does not exist".format(id)) } @@ -50,32 +51,34 @@ private[kafka] object Broker { val id = buffer.getInt val host = readShortString(buffer) val port = buffer.getInt - new Broker(id, host, port) + val secure = buffer.getShort == 1 + new Broker(id, host, port, secure) } } -private[kafka] case class Broker(val id: Int, val host: String, val port: Int) { +private[kafka] case class Broker(val id: Int, val host: String, val port: Int, val secure: Boolean = false) { - override def toString(): String = new String("id:" + id + ",host:" + host + ",port:" + port) + override def toString(): String = new String("id:" + id + ",host:" + host + ",port:" + port + ",secure:" + secure) - def getConnectionString(): String = host + ":" + port + def getConnectionString(): String = host + ":" + port + ":" + (if (secure) 1 else 0) def writeTo(buffer: ByteBuffer) { buffer.putInt(id) writeShortString(buffer, host) buffer.putInt(port) + buffer.putShort(if (secure) 1 else 0) } - def sizeInBytes: Int = 
shortStringLength(host) /* host name */ + 4 /* port */ + 4 /* broker id*/ + def sizeInBytes: Int = shortStringLength(host) /* host name */ + 4 /* port */ + 4 /* broker id*/ + 2 /* secure */ override def equals(obj: Any): Boolean = { obj match { case null => false - case n: Broker => id == n.id && host == n.host && port == n.port + case n: Broker => id == n.id && host == n.host && port == n.port && secure == n.secure case _ => false } } - override def hashCode(): Int = hashcode(id, host, port) + override def hashCode(): Int = hashcode(id, host, port, if (secure) 1 else 0) } diff --git a/core/src/main/scala/kafka/common/UnknownKeyStoreException.scala b/core/src/main/scala/kafka/common/UnknownKeyStoreException.scala new file mode 100644 index 0000000000000000000000000000000000000000..f8796cb56928e19972161350caa57addfac74b30 --- /dev/null +++ b/core/src/main/scala/kafka/common/UnknownKeyStoreException.scala @@ -0,0 +1,22 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.common + +class UnknownKeyStoreException(message: String) extends RuntimeException(message) { + def this() = this(null) +} diff --git a/core/src/main/scala/kafka/consumer/ConsumerConfig.scala b/core/src/main/scala/kafka/consumer/ConsumerConfig.scala index 9ebbee6c16dc83767297c729d2d74ebbd063a993..5e6a3ca7ab142cf3540a9ddf8ab634bf8e592838 100644 --- a/core/src/main/scala/kafka/consumer/ConsumerConfig.scala +++ b/core/src/main/scala/kafka/consumer/ConsumerConfig.scala @@ -178,6 +178,9 @@ class ConsumerConfig private (val props: VerifiableProperties) extends ZKConfig( /** Whether messages from internal topics (such as offsets) should be exposed to the consumer. */ val excludeInternalTopics = props.getBoolean("exclude.internal.topics", ExcludeInternalTopics) + /** security config file */ + val securityConfigFile = props.getString("security.config.file", null) + /** Select a strategy for assigning partitions to consumer streams. Possible values: range, roundrobin */ val partitionAssignmentStrategy = props.getString("partition.assignment.strategy", DefaultPartitionAssignmentStrategy) diff --git a/core/src/main/scala/kafka/consumer/ConsumerFetcherManager.scala b/core/src/main/scala/kafka/consumer/ConsumerFetcherManager.scala index b9e2bea7b442a19bcebd1b350d39541a8c9dd068..b990f6e3b57f066e2361d51f7cd76e4156b357dd 100644 --- a/core/src/main/scala/kafka/consumer/ConsumerFetcherManager.scala +++ b/core/src/main/scala/kafka/consumer/ConsumerFetcherManager.scala @@ -66,6 +66,7 @@ class ConsumerFetcherManager(private val consumerIdString: String, val topicsMetadata = ClientUtils.fetchTopicMetadata(noLeaderPartitionSet.map(m => m.topic).toSet, brokers, config.clientId, + config.securityConfigFile, config.socketTimeoutMs, correlationId.getAndIncrement).topicsMetadata if(logger.isDebugEnabled) topicsMetadata.foreach(topicMetadata => debug(topicMetadata.toString())) diff --git a/core/src/main/scala/kafka/consumer/SimpleConsumer.scala 
b/core/src/main/scala/kafka/consumer/SimpleConsumer.scala index 8db9203d164a4a54f94d8d289e070a0f61e03ff9..b4e83eb37cd2da2fa39db59c9e3c2fb2b49c30bd 100644 --- a/core/src/main/scala/kafka/consumer/SimpleConsumer.scala +++ b/core/src/main/scala/kafka/consumer/SimpleConsumer.scala @@ -21,6 +21,7 @@ import kafka.api._ import kafka.network._ import kafka.utils._ import kafka.common.{ErrorMapping, TopicAndPartition} +import kafka.network.security.{AuthConfig, SecureAuth} /** * A consumer of kafka messages @@ -30,11 +31,22 @@ class SimpleConsumer(val host: String, val port: Int, val soTimeout: Int, val bufferSize: Int, - val clientId: String) extends Logging { + val clientId: String, + val secure: Boolean = false, + val securityConfigFile: String = null) extends Logging { ConsumerConfig.validateClientId(clientId) private val lock = new Object() - private val blockingChannel = new BlockingChannel(host, port, bufferSize, BlockingChannel.UseDefaultBufferSize, soTimeout) + + if (secure) { + synchronized { + if (!SecureAuth.isInitialized){ + SecureAuth.initialize(new AuthConfig(securityConfigFile)) + } + } + } + + private val blockingChannel = new BlockingChannel(host, port, secure, bufferSize, BlockingChannel.UseDefaultBufferSize, soTimeout) val brokerInfo = "host_%s-port_%s".format(host, port) private val fetchRequestAndResponseStats = FetchRequestAndResponseStatsRegistry.getFetchRequestAndResponseStats(clientId) private var isClosed = false diff --git a/core/src/main/scala/kafka/controller/ControllerChannelManager.scala b/core/src/main/scala/kafka/controller/ControllerChannelManager.scala index ecbfa0f328ba6a652a758ab20cacef324a8b2fb8..db25942acfa98dceccad945580dd829838dfcd30 100644 --- a/core/src/main/scala/kafka/controller/ControllerChannelManager.scala +++ b/core/src/main/scala/kafka/controller/ControllerChannelManager.scala @@ -80,7 +80,7 @@ class ControllerChannelManager (private val controllerContext: ControllerContext private def addNewBroker(broker: Broker) { val 
messageQueue = new LinkedBlockingQueue[(RequestOrResponse, (RequestOrResponse) => Unit)](config.controllerMessageQueueSize) debug("Controller %d trying to connect to broker %d".format(config.brokerId,broker.id)) - val channel = new BlockingChannel(broker.host, broker.port, + val channel = new BlockingChannel(broker.host, broker.port, broker.secure, BlockingChannel.UseDefaultBufferSize, BlockingChannel.UseDefaultBufferSize, config.controllerSocketTimeoutMs) diff --git a/core/src/main/scala/kafka/network/BlockingChannel.scala b/core/src/main/scala/kafka/network/BlockingChannel.scala index eb7bb14d94cb3648c06d4de36a3b34aacbde4556..93321a10a2ed7c8f6412d369fe6825c2735beac7 100644 --- a/core/src/main/scala/kafka/network/BlockingChannel.scala +++ b/core/src/main/scala/kafka/network/BlockingChannel.scala @@ -21,6 +21,7 @@ import java.net.InetSocketAddress import java.nio.channels._ import kafka.utils.{nonthreadsafe, Logging} import kafka.api.RequestOrResponse +import kafka.network.security.SSLSocketChannel object BlockingChannel{ @@ -33,7 +34,8 @@ object BlockingChannel{ */ @nonthreadsafe class BlockingChannel( val host: String, - val port: Int, + val port: Int, + val secure: Boolean, val readBufferSize: Int, val writeBufferSize: Int, val readTimeoutMs: Int ) extends Logging { @@ -42,52 +44,44 @@ class BlockingChannel( val host: String, private var readChannel: ReadableByteChannel = null private var writeChannel: GatheringByteChannel = null private val lock = new Object() - + def connect() = lock synchronized { if(!connected) { - try { - channel = SocketChannel.open() - if(readBufferSize > 0) - channel.socket.setReceiveBufferSize(readBufferSize) - if(writeBufferSize > 0) - channel.socket.setSendBufferSize(writeBufferSize) - channel.configureBlocking(true) - channel.socket.setSoTimeout(readTimeoutMs) - channel.socket.setKeepAlive(true) - channel.socket.setTcpNoDelay(true) - channel.connect(new InetSocketAddress(host, port)) + channel = if (secure) 
SSLSocketChannel.makeSecureClientConnection(SocketChannel.open(), host, port) else SocketChannel.open() + if(readBufferSize > 0) + channel.socket.setReceiveBufferSize(readBufferSize) + if(writeBufferSize > 0) + channel.socket.setSendBufferSize(writeBufferSize) + if (secure) channel.asInstanceOf[SSLSocketChannel].simulateBlocking(true) else channel.configureBlocking(true) + channel.socket.setSoTimeout(readTimeoutMs) + channel.socket.setKeepAlive(true) + channel.socket.setTcpNoDelay(true) + channel.connect(new InetSocketAddress(host, port)) - writeChannel = channel - readChannel = Channels.newChannel(channel.socket().getInputStream) - connected = true - // settings may not match what we requested above - val msg = "Created socket with SO_TIMEOUT = %d (requested %d), SO_RCVBUF = %d (requested %d), SO_SNDBUF = %d (requested %d)." - debug(msg.format(channel.socket.getSoTimeout, - readTimeoutMs, - channel.socket.getReceiveBufferSize, - readBufferSize, - channel.socket.getSendBufferSize, - writeBufferSize)) - } catch { - case e: Throwable => disconnect() - } + writeChannel = channel + readChannel = Channels.newChannel(Channels.newInputStream(channel)) + connected = true + val msg = "Created socket with SO_TIMEOUT = %d (requested %d), SO_RCVBUF = %d (requested %d), SO_SNDBUF = %d (requested %d)." + debug(msg.format(channel.socket.getSoTimeout, + readTimeoutMs, + channel.socket.getReceiveBufferSize, + readBufferSize, + channel.socket.getSendBufferSize, + writeBufferSize)) } } - + def disconnect() = lock synchronized { - if(channel != null) { + if(connected && channel != null) { + debug("Disconnecting channel " + channel.socket.getRemoteSocketAddress()) + // closing the main socket channel *should* close the read channel + // but let's do it to be sure. swallow(channel.close()) swallow(channel.socket.close()) - channel = null - writeChannel = null - } - // closing the main socket channel *should* close the read channel - // but let's do it to be sure. 
- if(readChannel != null) { swallow(readChannel.close()) - readChannel = null + channel = null; readChannel = null; writeChannel = null + connected = false } - connected = false } def isConnected = connected @@ -99,7 +93,7 @@ class BlockingChannel( val host: String, val send = new BoundedByteBufferSend(request) send.writeCompletely(writeChannel) } - + def receive(): Receive = { if(!connected) throw new ClosedChannelException() @@ -110,4 +104,4 @@ class BlockingChannel( val host: String, response } -} +} \ No newline at end of file diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala index d67899080c21e0b6db84657d6845c7ef23b59b0e..1ddfd48f7425debe79a1c1885b2bca6981d6b86d 100644 --- a/core/src/main/scala/kafka/network/SocketServer.scala +++ b/core/src/main/scala/kafka/network/SocketServer.scala @@ -28,8 +28,10 @@ import scala.collection._ import kafka.common.KafkaException import kafka.metrics.KafkaMetricsGroup +import kafka.network.security._ import kafka.utils._ import com.yammer.metrics.core.{Gauge, Meter} +import javax.net.ssl.SSLException /** * An NIO socket server. 
The threading model is @@ -40,10 +42,12 @@ import com.yammer.metrics.core.{Gauge, Meter} class SocketServer(val brokerId: Int, val host: String, val port: Int, + val secure: Boolean, + val securityConfig: AuthConfig, val numProcessorThreads: Int, val maxQueuedRequests: Int, val sendBufferSize: Int, - val recvBufferSize: Int, + val receiveBufferSize: Int, val maxRequestSize: Int = Int.MaxValue, val maxConnectionsPerIp: Int = Int.MaxValue, val maxConnectionsPerIpOverrides: Map[String, Int] = Map[String, Int]()) extends Logging with KafkaMetricsGroup { @@ -60,6 +64,9 @@ class SocketServer(val brokerId: Int, * Start the socket server */ def startup() { + // If secure setup SSLContext + if (secure) SecureAuth.initialize(securityConfig) + val quotas = new ConnectionQuotas(maxConnectionsPerIp, maxConnectionsPerIpOverrides) for(i <- 0 until numProcessorThreads) { processors(i) = new Processor(i, @@ -69,7 +76,8 @@ class SocketServer(val brokerId: Int, newMeter("NetworkProcessor-" + i + "-IdlePercent", "percent", TimeUnit.NANOSECONDS), numProcessorThreads, requestChannel, - quotas) + quotas, + secure) Utils.newThread("kafka-network-thread-%d-%d".format(port, i), processors(i), false).start() } @@ -78,10 +86,10 @@ class SocketServer(val brokerId: Int, }) // register the processor threads for notification of responses - requestChannel.addResponseListener((id:Int) => processors(id).wakeup()) - + requestChannel.addResponseListener((id: Int) => processors(id).wakeup()) + // start accepting connections - this.acceptor = new Acceptor(host, port, processors, sendBufferSize, recvBufferSize, quotas) + this.acceptor = new Acceptor(host, port, secure, securityConfig, processors, sendBufferSize, receiveBufferSize, quotas) Utils.newThread("kafka-socket-acceptor", acceptor, false).start() acceptor.awaitStartup info("Started") @@ -92,9 +100,9 @@ class SocketServer(val brokerId: Int, */ def shutdown() = { info("Shutting down") - if(acceptor != null) + if (acceptor != null) 
acceptor.shutdown() - for(processor <- processors) + for (processor <- processors) processor.shutdown() info("Shutdown completed") } @@ -140,7 +148,7 @@ private[kafka] abstract class AbstractServerThread(connectionQuotas: ConnectionQ * Is the server still running? */ protected def isRunning = alive.get - + /** * Wakeup the thread for selection. */ @@ -187,16 +195,19 @@ private[kafka] abstract class AbstractServerThread(connectionQuotas: ConnectionQ } count } + } /** * Thread that accepts and configures new connections. There is only need for one of these */ private[kafka] class Acceptor(val host: String, - val port: Int, + val port: Int, + val secure: Boolean, + val securityConfig: AuthConfig, private val processors: Array[Processor], val sendBufferSize: Int, - val recvBufferSize: Int, + val receiveBufferSize: Int, connectionQuotas: ConnectionQuotas) extends AbstractServerThread(connectionQuotas) { val serverChannel = openServerSocket(host, port) @@ -204,23 +215,23 @@ private[kafka] class Acceptor(val host: String, * Accept loop that checks for new connection attempts */ def run() { - serverChannel.register(selector, SelectionKey.OP_ACCEPT); + serverChannel.register(selector, SelectionKey.OP_ACCEPT) startupComplete() var currentProcessor = 0 - while(isRunning) { + while (isRunning) { val ready = selector.select(500) - if(ready > 0) { + if (ready > 0) { val keys = selector.selectedKeys() val iter = keys.iterator() - while(iter.hasNext && isRunning) { + while (iter.hasNext && isRunning) { var key: SelectionKey = null try { key = iter.next iter.remove() - if(key.isAcceptable) - accept(key, processors(currentProcessor)) + if (key.isAcceptable) + accept(key, processors(currentProcessor)) else - throw new IllegalStateException("Unrecognized key state for acceptor thread.") + throw new IllegalStateException("Unrecognized key state for acceptor thread.") // round robin to the next processor thread currentProcessor = (currentProcessor + 1) % processors.length @@ -235,24 
+246,23 @@ private[kafka] class Acceptor(val host: String, swallowError(selector.close()) shutdownComplete() } - + /* * Create a server socket to listen for connections on. */ def openServerSocket(host: String, port: Int): ServerSocketChannel = { - val socketAddress = - if(host == null || host.trim.isEmpty) + val socketAddress = + if (host == null || host.trim.isEmpty) new InetSocketAddress(port) else new InetSocketAddress(host, port) val serverChannel = ServerSocketChannel.open() serverChannel.configureBlocking(false) - serverChannel.socket().setReceiveBufferSize(recvBufferSize) try { serverChannel.socket.bind(socketAddress) info("Awaiting socket connections on %s:%d.".format(socketAddress.getHostName, port)) } catch { - case e: SocketException => + case e: SocketException => throw new KafkaException("Socket server failed to bind to %s:%d: %s.".format(socketAddress.getHostName, port, e.getMessage), e) } serverChannel @@ -263,7 +273,10 @@ private[kafka] class Acceptor(val host: String, */ def accept(key: SelectionKey, processor: Processor) { val serverSocketChannel = key.channel().asInstanceOf[ServerSocketChannel] - val socketChannel = serverSocketChannel.accept() + serverSocketChannel.socket().setReceiveBufferSize(receiveBufferSize) + + val sch = serverSocketChannel.accept() + val socketChannel = if (secure) SSLSocketChannel.makeSecureServerConnection(sch, securityConfig.wantClientAuth, securityConfig.needClientAuth) else sch try { connectionQuotas.inc(socketChannel.socket().getInetAddress) socketChannel.configureBlocking(false) @@ -273,18 +286,20 @@ private[kafka] class Acceptor(val host: String, debug("Accepted connection from %s on %s. 
sendBufferSize [actual|requested]: [%d|%d] recvBufferSize [actual|requested]: [%d|%d]" .format(socketChannel.socket.getInetAddress, socketChannel.socket.getLocalSocketAddress, socketChannel.socket.getSendBufferSize, sendBufferSize, - socketChannel.socket.getReceiveBufferSize, recvBufferSize)) + socketChannel.socket.getReceiveBufferSize)) processor.accept(socketChannel) } catch { case e: TooManyConnectionsException => info("Rejected connection from %s, address already has the configured maximum of %d connections.".format(e.ip, e.count)) - close(socketChannel) + socketChannel.close() } } } +private case class ChannelTuple(value: Any, sslChannel: SSLSocketChannel) + /** * Thread that processes all requests from a single connection. There are N of these running in parallel * each of which has its own selectors @@ -296,57 +311,66 @@ private[kafka] class Processor(val id: Int, val idleMeter: Meter, val totalProcessorThreads: Int, val requestChannel: RequestChannel, - connectionQuotas: ConnectionQuotas) extends AbstractServerThread(connectionQuotas) { + connectionQuotas: ConnectionQuotas, + val secure: Boolean) extends AbstractServerThread(connectionQuotas) { private val newConnections = new ConcurrentLinkedQueue[SocketChannel]() override def run() { startupComplete() - while(isRunning) { - // setup any new connections that have been queued up - configureNewConnections() - // register any new responses for writing - processNewResponses() - val startSelectTime = SystemTime.nanoseconds - val ready = selector.select(300) - val idleTime = SystemTime.nanoseconds - startSelectTime - idleMeter.mark(idleTime) - // We use a single meter for aggregate idle percentage for the thread pool. - // Since meter is calculated as total_recorded_value / time_window and - // time_window is independent of the number of threads, each recorded idle - // time should be discounted by # threads. 
- aggregateIdleMeter.mark(idleTime / totalProcessorThreads) - - trace("Processor id " + id + " selection time = " + idleTime + " ns") - if(ready > 0) { - val keys = selector.selectedKeys() - val iter = keys.iterator() - while(iter.hasNext && isRunning) { - var key: SelectionKey = null - try { - key = iter.next - iter.remove() - if(key.isReadable) - read(key) - else if(key.isWritable) - write(key) - else if(!key.isValid) - close(key) - else - throw new IllegalStateException("Unrecognized key state for processor thread.") - } catch { - case e: EOFException => { - info("Closing socket connection to %s.".format(channelFor(key).socket.getInetAddress)) - close(key) - } case e: InvalidRequestException => { - info("Closing socket connection to %s due to invalid request: %s".format(channelFor(key).socket.getInetAddress, e.getMessage)) - close(key) - } case e: Throwable => { - error("Closing socket for " + channelFor(key).socket.getInetAddress + " because of error", e) - close(key) + while (isRunning) { + try { + // setup any new connections that have been queued up + configureNewConnections() + // register any new responses for writing + processNewResponses() + val startSelectTime = SystemTime.nanoseconds + val ready = selector.select(300) + val idleTime = SystemTime.nanoseconds - startSelectTime + idleMeter.mark(idleTime) + // We use a single meter for aggregate idle percentage for the thread pool. + // Since meter is calculated as total_recorded_value / time_window and + // time_window is independent of the number of threads, each recorded idle + // time should be discounted by # threads. 
+ aggregateIdleMeter.mark(idleTime / totalProcessorThreads) + + trace("Processor id " + id + " selection time = " + idleTime + " ns") + if (ready > 0) { + val keys = selector.selectedKeys() + val iter = keys.iterator() + while (iter.hasNext && isRunning) { + var key: SelectionKey = null + try { + key = iter.next + iter.remove() + if (key.isReadable) + read(key) + else if (key.isWritable) + write(key) + else if (!key.isValid) + close(key) + else + throw new IllegalStateException("Unrecognized key state for processor thread.") + } catch { + case e: EOFException => { + info("Closing socket connection to %s.".format(channelFor(key).socket.getInetAddress)) + close(key) + } + case e: InvalidRequestException => { + info("Closing socket connection to %s due to invalid request: %s".format(channelFor(key).socket.getInetAddress, e.getMessage)) + close(key) + } + case e: Throwable => { + error("Closing socket for " + channelFor(key).socket.getInetAddress + " because of error", e) + close(key) + } } } } + } catch { + case e: Throwable => { + error("Unexpected error", e) + } } } debug("Closing selector.") @@ -357,25 +381,27 @@ private[kafka] class Processor(val id: Int, private def processNewResponses() { var curr = requestChannel.receiveResponse(id) - while(curr != null) { + while (curr != null) { val key = curr.request.requestKey.asInstanceOf[SelectionKey] + val channelTuple = key.attachment.asInstanceOf[ChannelTuple] try { curr.responseAction match { case RequestChannel.NoOpAction => { // There is no response to send to the client, we need to read more pipelined requests // that are sitting in the server's socket buffer - curr.request.updateRequestMetrics + curr.request.updateRequestMetrics() trace("Socket server received empty response to send, registering for read: " + curr) key.interestOps(SelectionKey.OP_READ) - key.attach(null) + key.attach(ChannelTuple(null, channelTuple.sslChannel)) + readBufferedSSLDataIfNeeded(key, channelTuple) } case RequestChannel.SendAction => { 
trace("Socket server received response to send, registering for write: " + curr) key.interestOps(SelectionKey.OP_WRITE) - key.attach(curr) + key.attach(ChannelTuple(curr, channelTuple.sslChannel)) } case RequestChannel.CloseConnectionAction => { - curr.request.updateRequestMetrics + curr.request.updateRequestMetrics() trace("Closing socket connection actively according to the response code.") close(key) } @@ -392,6 +418,29 @@ private[kafka] class Processor(val id: Int, } } + override def close(key: SelectionKey) { + try { + val channel = channelFor(key) + debug("Closing connection from " + channel.socket.getRemoteSocketAddress) + swallowError(channel.close()) + swallowError(channel.socket().close()) + } finally { + key.attach(null) + swallowError(key.cancel()) + } + } + + /* + * Close all open connections + */ + override def closeAll() { + val iter = this.selector.keys().iterator() + while (iter.hasNext) { + val key = iter.next() + close(key) + } + } + /** * Queue up a new connection for reading */ @@ -404,10 +453,17 @@ private[kafka] class Processor(val id: Int, * Register any new connections that have been queued up */ private def configureNewConnections() { - while(newConnections.size() > 0) { + while (newConnections.size() > 0) { val channel = newConnections.poll() - debug("Processor " + id + " listening to new connection from " + channel.socket.getRemoteSocketAddress) - channel.register(selector, SelectionKey.OP_READ) + debug("Processor %s listening to new connection from %s".format(id, channel.socket.getRemoteSocketAddress)) + val (regChannel, sslsch) = channel match { + case sslsch: SSLSocketChannel => + val rch = sslsch.underlying.asInstanceOf[SocketChannel] + (rch, sslsch) + case _ => (channel, null) + } + val key = regChannel.register(selector, SelectionKey.OP_READ) + key.attach(ChannelTuple(null, sslsch)) } } @@ -415,26 +471,28 @@ private[kafka] class Processor(val id: Int, * Process reads from ready sockets */ def read(key: SelectionKey) { - val 
socketChannel = channelFor(key) - var receive = key.attachment.asInstanceOf[Receive] - if(key.attachment == null) { + val channelTuple = key.attachment.asInstanceOf[ChannelTuple] + val socketChannel = channelFor(key, SelectionKey.OP_READ) + if (socketChannel == null) return + var receive = channelTuple.value.asInstanceOf[Receive] + if (receive == null) { receive = new BoundedByteBufferReceive(maxRequestSize) - key.attach(receive) + key.attach(ChannelTuple(receive, channelTuple.sslChannel)) } val read = receive.readFrom(socketChannel) - val address = socketChannel.socket.getRemoteSocketAddress(); + val address = socketChannel.socket.getRemoteSocketAddress trace(read + " bytes read from " + address) - if(read < 0) { + if (read < 0) { close(key) - } else if(receive.complete) { + } else if (receive.complete) { val req = RequestChannel.Request(processor = id, requestKey = key, buffer = receive.buffer, startTimeMs = time.milliseconds, remoteAddress = address) requestChannel.sendRequest(req) - key.attach(null) + key.attach(ChannelTuple(null, channelTuple.sslChannel)) // explicitly reset interest ops to not READ, no need to wake up the selector just yet key.interestOps(key.interestOps & (~SelectionKey.OP_READ)) } else { // more reading to be done - trace("Did not finish reading, registering for read again on connection " + socketChannel.socket.getRemoteSocketAddress()) + trace("Did not finish reading, registering for read again on connection " + socketChannel.socket.getRemoteSocketAddress) key.interestOps(SelectionKey.OP_READ) wakeup() } @@ -444,27 +502,78 @@ private[kafka] class Processor(val id: Int, * Process writes to ready sockets */ def write(key: SelectionKey) { - val socketChannel = channelFor(key) - val response = key.attachment().asInstanceOf[RequestChannel.Response] + val channelTuple = key.attachment.asInstanceOf[ChannelTuple] + val socketChannel = channelFor(key, SelectionKey.OP_WRITE) + if (socketChannel == null) return + val response = 
channelTuple.value.asInstanceOf[RequestChannel.Response] val responseSend = response.responseSend - if(responseSend == null) + if (responseSend == null) throw new IllegalStateException("Registered for write interest but no response attached to key.") val written = responseSend.writeTo(socketChannel) - trace(written + " bytes written to " + socketChannel.socket.getRemoteSocketAddress() + " using key " + key) - if(responseSend.complete) { + trace(written + " bytes written to " + socketChannel.socket.getRemoteSocketAddress + " using key " + key) + if (responseSend.complete) { response.request.updateRequestMetrics() - key.attach(null) - trace("Finished writing, registering for read on connection " + socketChannel.socket.getRemoteSocketAddress()) + key.attach(ChannelTuple(null, channelTuple.sslChannel)) + trace("Finished writing, registering for read on connection " + socketChannel.socket.getRemoteSocketAddress) key.interestOps(SelectionKey.OP_READ) + readBufferedSSLDataIfNeeded(key, channelTuple) } else { - trace("Did not finish writing, registering for write again on connection " + socketChannel.socket.getRemoteSocketAddress()) + trace("Did not finish writing, registering for write again on connection " + socketChannel.socket.getRemoteSocketAddress) key.interestOps(SelectionKey.OP_WRITE) wakeup() } } - private def channelFor(key: SelectionKey) = key.channel().asInstanceOf[SocketChannel] + private def channelFor(key: SelectionKey, ops: Int = -1) = { + val sch = key.channel.asInstanceOf[SocketChannel] + if (secure) { + val secureSocketChannel = key.attachment.asInstanceOf[ChannelTuple].sslChannel + if (ops >= 0 && !secureSocketChannel.finished()) { + var done = false + try { + val next = secureSocketChannel.handshake(key.interestOps(), key) + if (next == 0) { + // when handshake is complete and we are doing a read so ahead with the read + // otherwise go back to read mode + if (ops == SelectionKey.OP_READ) { + done = true + } else { + 
key.interestOps(SelectionKey.OP_READ) + } + } else if (next != SSLSocketChannel.runningTasks) { + key.interestOps(next) + } + } catch { + case e: SSLException => // just ignore SSL disconnect errors + debug("SSLException: " + e) + close(key) + } + if (done) secureSocketChannel else null + } else secureSocketChannel + } else sch + } + private[this] def readBufferedSSLDataIfNeeded(key: SelectionKey, channelTuple: ChannelTuple) { + try { + if (channelTuple.sslChannel != null && channelTuple.sslChannel.isReadable) { + read(key) + } + } catch { + case e: EOFException => { + info("Closing socket connection to %s.".format(channelFor(key).socket.getInetAddress)) + close(key) + } + case e: InvalidRequestException => { + info("Closing socket connection to %s due to invalid request: %s".format(channelFor(key).socket.getInetAddress, + e.getMessage)) + close(key) + } + case e: Throwable => { + error("Closing socket for %s because of error".format(channelFor(key).socket.getInetAddress), e) + close(key) + } + } + } } class ConnectionQuotas(val defaultMax: Int, overrideQuotas: Map[String, Int]) { diff --git a/core/src/main/scala/kafka/network/security/AuthConfig.scala b/core/src/main/scala/kafka/network/security/AuthConfig.scala new file mode 100644 index 0000000000000000000000000000000000000000..37bb6e9aca2f497a513fa91c70df31cc5af9f949 --- /dev/null +++ b/core/src/main/scala/kafka/network/security/AuthConfig.scala @@ -0,0 +1,60 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.network.security + +import kafka.utils.VerifiableProperties +import kafka.utils.Utils +import kafka.utils.Logging + +object AuthConfig { + val DEFAULT_SECURITY_CONFIG = "config/client.security.properties" +} + +class AuthConfig(var securityConfigFile: String) extends Logging { + + val props = { + if (securityConfigFile == null) { + warn("securityConfigFile is null, using default securityConfigFile %s".format(AuthConfig.DEFAULT_SECURITY_CONFIG)) + securityConfigFile = AuthConfig.DEFAULT_SECURITY_CONFIG + } + new VerifiableProperties(Utils.loadProps(securityConfigFile)) + } + + val keystoreType = props.getString("keystore.type") + + /** Request client auth */ + val wantClientAuth = props.getBoolean("want.client.auth", false) + + /** Require client auth */ + val needClientAuth = props.getBoolean("need.client.auth", false) + + /** Keystore file location */ + val keystore = props.getString("keystore") + + /** Keystore file password */ + val keystorePwd = props.getString("keystorePwd") + + /** Keystore key password */ + val keyPwd = props.getString("keyPwd") + + /** Truststore file location */ + val truststore = props.getString("truststore") + + /** Truststore file password */ + val truststorePwd = props.getString("truststorePwd") +} diff --git a/core/src/main/scala/kafka/network/security/KeyStores.scala b/core/src/main/scala/kafka/network/security/KeyStores.scala new file mode 100644 index 0000000000000000000000000000000000000000..706209e37e43b2679845e1b13177ba32091e2596 --- /dev/null +++ 
b/core/src/main/scala/kafka/network/security/KeyStores.scala @@ -0,0 +1,35 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.network.security + +import kafka.network.security.store.JKSInitializer +import kafka.common.UnknownKeyStoreException +import javax.net.ssl.SSLContext + +trait StoreInitializer { + def initialize(config: AuthConfig): SSLContext +} + +object KeyStores { + def getKeyStore(name: String): StoreInitializer = { + name.toLowerCase match { + case JKSInitializer.name => JKSInitializer + case _ => throw new UnknownKeyStoreException("%s is an unknown key store".format(name)) + } + } +} diff --git a/core/src/main/scala/kafka/network/security/SSLSocketChannel.scala b/core/src/main/scala/kafka/network/security/SSLSocketChannel.scala new file mode 100644 index 0000000000000000000000000000000000000000..784252d1a12a3ec12dc87172d77e8ffcf755f4e7 --- /dev/null +++ b/core/src/main/scala/kafka/network/security/SSLSocketChannel.scala @@ -0,0 +1,657 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.network.security + +import java.io.IOException +import java.net._ +import java.nio.ByteBuffer +import java.nio.channels._ +import javax.net.ssl._ +import javax.net.ssl.SSLEngineResult._ +import java.util.concurrent._ +import java.util.concurrent.atomic.AtomicInteger +import kafka.utils.Logging +import java.util + +object SSLSocketChannel { + + def makeSecureClientConnection(sch: SocketChannel, host: String, port: Int) = { + val engine = SecureAuth.sslContext.createSSLEngine(host, port) + engine.setEnabledProtocols(Array("SSLv3")) + engine.setUseClientMode(true) + new SSLSocketChannel(sch, engine) + } + + def makeSecureServerConnection(socketChannel: SocketChannel, + wantClientAuth: Boolean = true, + needClientAuth: Boolean = true) = { + val engine = socketChannel.socket.getRemoteSocketAddress match { + case ise: InetSocketAddress => + SecureAuth.sslContext.createSSLEngine(ise.getHostName, ise.getPort) + case _ => + SecureAuth.sslContext.createSSLEngine() + } + engine.setEnabledProtocols(Array("SSLv3")) + engine.setUseClientMode(false) + if (wantClientAuth) { + engine.setWantClientAuth(true) + } + if (needClientAuth) { + engine.setNeedClientAuth(true) + } + new SSLSocketChannel(socketChannel, engine) + } + + val simulateSlowNetwork = false + + val runningTasks = -2 + private[this] lazy val counter = new AtomicInteger(0) + private[kafka] lazy val executor = new 
ThreadPoolExecutor(2, 10, + 60L, TimeUnit.SECONDS, + new SynchronousQueue[Runnable](), + new ThreadFactory() { + override def newThread(r: Runnable): Thread = { + val thread = new Thread(r, "SSLSession-Task-Thread-%d".format(counter.incrementAndGet())) + thread.setDaemon(true) + thread + } + }) +} + +class SSLSocketChannel(val underlying: SocketChannel, val sslEngine: SSLEngine) + extends SocketChannel(underlying.provider) with Logging { + + import SSLSocketChannel.executor + + private[this] class SSLTasker(val runnable: Runnable) extends Runnable { + selectionKey.interestOps(0) + + override def run(): Unit = { + try { + runnable.run() + outer.synchronized { + handshakeStatus = sslEngine.getHandshakeStatus + handshakeStatus match { + case HandshakeStatus.NEED_WRAP => + debug("sslTasker setting up to write for %s".format(underlying.socket.getRemoteSocketAddress)) + selectionKey.interestOps(SelectionKey.OP_WRITE) + case HandshakeStatus.NEED_UNWRAP => + if (peerNetData.position > 0) { + debug("sslTasker found existing data %s. 
running hanshake for %s".format(peerNetData, + underlying.socket.getRemoteSocketAddress)) + val init = outer.handshake(SelectionKey.OP_READ, selectionKey) + if (init == 0) { + debug("sslTasker setting up to read after hanshake") + selectionKey.interestOps(SelectionKey.OP_READ) + } else if (init != SSLSocketChannel.runningTasks) { + debug("sslTasker setting up for operation %d after hanshake for %s".format(init, + underlying.socket.getRemoteSocketAddress)) + selectionKey.interestOps(init) + } + } else { + debug("sslTasker setting up to read for %s".format(underlying.socket.getRemoteSocketAddress)) + selectionKey.interestOps(SelectionKey.OP_READ) + } + case HandshakeStatus.NEED_TASK => + val runnable = sslEngine.getDelegatedTask + if (runnable != null) { + debug("sslTasker running next task for %s".format(underlying.socket.getRemoteSocketAddress)) + executor.execute(new SSLTasker(runnable)) + handshakeStatus = null + } + return + case _ => + throw new SSLException("unexpected handshakeStatus: " + handshakeStatus) + } + selectionKey.selector.wakeup() + } + } catch { + case t: Throwable => + error("Unexpected exception", t) + } + } + } + + private[this] val outer = this + + /** + * The engine handshake status. + */ + private[this] var handshakeStatus: HandshakeStatus = HandshakeStatus.NOT_HANDSHAKING + + /** + * The initial handshake ops. 
+ */ + @volatile private[this] var initialized = -1 + + /** + * Marker for shutdown status + */ + private[this] var shutdown = false + + private[this] var peerAppData = ByteBuffer.allocate(sslEngine.getSession.getApplicationBufferSize) + private[this] var myNetData = ByteBuffer.allocate(sslEngine.getSession.getPacketBufferSize) + private[this] var peerNetData = ByteBuffer.allocate(sslEngine.getSession.getPacketBufferSize) + private[this] val emptyBuffer = ByteBuffer.allocate(0) + + myNetData.limit(0) + + underlying.configureBlocking(false) + + private[this] var blocking = false + private[this] lazy val blockingSelector = Selector.open() + private[this] var blockingKey: SelectionKey = null + + @volatile private[this] var selectionKey: SelectionKey = null + + def simulateBlocking(b: Boolean) = { + blocking = b + } + + def socket(): Socket = underlying.socket + + def isConnected: Boolean = underlying.isConnected + + def isConnectionPending: Boolean = underlying.isConnectionPending + + def connect(remote: SocketAddress): Boolean = { + debug("SSLSocketChannel Connecting to Remote : " + remote) + val ret = underlying.connect(remote) + if (blocking) { + while (!finishConnect()) { + try { + Thread.sleep(10) + } catch { + case _: InterruptedException => + } + } + blockingKey = underlying.register(blockingSelector, SelectionKey.OP_READ) + handshakeInBlockMode(SelectionKey.OP_WRITE) + true + } else ret + } + + def finishConnect(): Boolean = underlying.finishConnect() + + def isReadable = finished && (peerAppData.position > 0 || peerNetData.position > 0) + + def read(dst: ByteBuffer): Int = { + this.synchronized { + if (peerAppData.position >= dst.remaining) { + return readFromPeerData(dst) + } else if (underlying.socket.isInputShutdown) { + throw new ClosedChannelException + } else if (initialized != 0) { + handshake(SelectionKey.OP_READ, selectionKey) + return 0 + } else if (shutdown) { + shutdown() + return -1 + } else if (sslEngine.isInboundDone) { + return -1 + } else { + 
val count = readRaw() + if (count <= 0 && peerNetData.position == 0) return count.asInstanceOf[Int] + } + + if (unwrap(false) < 0) return -1 + + readFromPeerData(dst) + } + } + + def read(destination: Array[ByteBuffer], offset: Int, length: Int): Long = { + var n = 0 + var i = offset + def localReadLoop() { + while (i < length) { + if (destination(i).hasRemaining) { + val x = read(destination(i)) + if (x > 0) { + n += x + if (!destination(i).hasRemaining) { + return + } + } else { + if ((x < 0) && (n == 0)) { + n = -1 + } + return + } + } + i = i + 1 + } + } + localReadLoop() + n + } + + def write(source: ByteBuffer): Int = { + this.synchronized { + if (myNetData.hasRemaining) { + writeRaw(myNetData) + return 0 + } else if (underlying.socket.isOutputShutdown) { + throw new ClosedChannelException + } else if (initialized != 0) { + handshake(SelectionKey.OP_WRITE, selectionKey) + return 0 + } else if (shutdown) { + shutdown() + return -1 + } + + val written = wrap(source) + + while (myNetData.hasRemaining) + writeRaw(myNetData) + written + } + } + + def write(sources: Array[ByteBuffer], offset: Int, length: Int): Long = { + var n = 0 + var i = offset + def localWriteLoop { + while (i < length) { + if (sources(i).hasRemaining) { + var x = write(sources(i)) + if (x > 0) { + n += x + if (!sources(i).hasRemaining) { + return + } + } else { + return + } + } + i = i + 1 + } + } + localWriteLoop + n + } + + def finished(): Boolean = initialized == 0 + + override def toString = "SSLSocketChannel[" + underlying.toString + "]" + + protected def implCloseSelectableChannel(): Unit = { + try { + _shutdown() + } catch { + case x: Exception => + } + underlying.close() + } + + protected def implConfigureBlocking(block: Boolean): Unit = { + simulateBlocking(block) + if (!block) underlying.configureBlocking(block) + } + + def handshake(o: Int, key: SelectionKey): Int = { + def writeIfReadyAndNeeded(mustWrite: Boolean): Boolean = { + if ((o & SelectionKey.OP_WRITE) != 0) { + 
writeRaw(myNetData) + myNetData.remaining > 0 + } else mustWrite + } + def readIfReadyAndNeeded(mustRead: Boolean): Boolean = { + if ((o & SelectionKey.OP_READ) != 0) { + if (readRaw() < 0) { + shutdown = true + underlying.close() + return true + } + val oldPos = peerNetData.position + unwrap(true) + oldPos == peerNetData.position + } else mustRead + } + + def localHandshake(): Int = { + while (true) { + handshakeStatus match { + case HandshakeStatus.NOT_HANDSHAKING => + info("begin ssl handshake for %s/%s".format(underlying.socket.getRemoteSocketAddress, + underlying.socket.getLocalSocketAddress)) + sslEngine.beginHandshake() + handshakeStatus = sslEngine.getHandshakeStatus + case HandshakeStatus.NEED_UNWRAP => + debug("need unwrap in ssl handshake for %s/%s".format(underlying.socket.getRemoteSocketAddress, + underlying.socket.getLocalSocketAddress)) + if (readIfReadyAndNeeded(true) && handshakeStatus != HandshakeStatus.FINISHED) { + debug("select to read more for %s/%s".format(underlying.socket.getRemoteSocketAddress, + underlying.socket.getLocalSocketAddress)) + return SelectionKey.OP_READ + } + case HandshakeStatus.NEED_WRAP => + debug("need wrap in ssl handshake for %s/%s".format(underlying.socket.getRemoteSocketAddress, + underlying.socket.getLocalSocketAddress)) + if (myNetData.remaining == 0) { + wrap(emptyBuffer) + } + if (writeIfReadyAndNeeded(true)) { + debug("select to write more for %s/%s".format(underlying.socket.getRemoteSocketAddress, + underlying.socket.getLocalSocketAddress)) + return SelectionKey.OP_WRITE + } + case HandshakeStatus.NEED_TASK => + handshakeStatus = runTasks() + case HandshakeStatus.FINISHED => + info("finished ssl handshake for %s/%s".format(underlying.socket.getRemoteSocketAddress, + underlying.socket.getLocalSocketAddress)) + return 0 + case null => + return SSLSocketChannel.runningTasks + } + } + o + } + + this.synchronized { + if (initialized == 0) return initialized + + if (selectionKey == null) selectionKey = key + + if 
(initialized != -1) { + if (writeIfReadyAndNeeded(false)) return o + } + val init = localHandshake() + if (init != SSLSocketChannel.runningTasks) { + initialized = init + } + init + } + } + + def shutdown() { + debug("SSLSocketChannel shutting down with locking") + this.synchronized(_shutdown()) + underlying.close() + } + + private def _shutdown() { + debug("SSLSocketChannel shutting down with out locking") + shutdown = true + + try { + if (!sslEngine.isOutboundDone) sslEngine.closeOutbound() + + myNetData.compact() + while (!sslEngine.isOutboundDone) { + val res = sslEngine.wrap(emptyBuffer, myNetData) + if (res.getStatus != Status.CLOSED) { + throw new SSLException("Unexpected shutdown status '%s'".format(res.getStatus)) + } + + myNetData.flip() + try { + while (myNetData.hasRemaining) + writeRaw(myNetData) + } catch { + case ignore: IOException => + } + } + } finally { + if (blockingKey != null) { + try { + blockingKey.cancel() + } finally { + blockingKey = null + blockingSelector.close() + } + } + } + } + + private def handshakeInBlockMode(ops: Int) = { + var o = ops + while (o != 0) { + val tops = handshake(o, null) + if (tops == o) { + try { + Thread.sleep(10) + } catch { + case _: InterruptedException => + } + } else { + o = tops + } + } + o + } + + private[this] def readRaw(): Long = { + def blockIfNeeded() { + if (blockingKey != null) { + try { + blockingSelector.select(5000) + } catch { + case t: Throwable => error("Unexpected error in blocking select", t) + } + } + } + this.synchronized { + blockIfNeeded() + try { + val n = underlying.read(peerNetData) + if (n < 0) { + sslEngine.closeInbound() + } + n + } catch { + case x: IOException => + sslEngine.closeInbound() + throw x + } + } + } + + private[this] def unwrap(isHandshaking: Boolean): Int = { + val pos = peerAppData.position + peerNetData.flip() + trace("unwrap: flipped peerNetData %s for %s/%s".format(peerNetData, + underlying.socket.getRemoteSocketAddress, + 
underlying.socket.getLocalSocketAddress)) + try { + while (peerNetData.hasRemaining) { + val result = sslEngine.unwrap(peerNetData, peerAppData) + handshakeStatus = result.getHandshakeStatus + result.getStatus match { + case SSLEngineResult.Status.OK => + if (handshakeStatus == HandshakeStatus.NEED_TASK) { + handshakeStatus = runTasks() + if (handshakeStatus == null) return 0 + } + if (isHandshaking && handshakeStatus == HandshakeStatus.FINISHED) { + return peerAppData.position - pos + } + case SSLEngineResult.Status.BUFFER_OVERFLOW => + peerAppData = expand(peerAppData, sslEngine.getSession.getApplicationBufferSize) + case SSLEngineResult.Status.BUFFER_UNDERFLOW => + return 0 + case SSLEngineResult.Status.CLOSED => + if (peerAppData.position == 0) { + trace("uwrap: shutdown for %s/%s".format(peerAppData, + underlying.socket.getRemoteSocketAddress, + underlying.socket.getLocalSocketAddress)) + shutdown() + return -1 + } else { + trace("uwrap: shutdown with non-empty peerAppData %s for %s/%s".format(peerAppData, + underlying.socket.getRemoteSocketAddress, + underlying.socket.getLocalSocketAddress)) + shutdown = true + return 0 + } + case _ => + throw new SSLException("Unexpected state!") + } + } + } finally { + peerNetData.compact() + trace("unwrap: compacted peerNetData %s for %s/%s".format(peerNetData, + underlying.socket.getRemoteSocketAddress, + underlying.socket.getLocalSocketAddress)) + } + peerAppData.position - pos + } + + private[this] def wrap(src: ByteBuffer): Int = { + val written = src.remaining + myNetData.compact() + trace("wrap: compacted myNetData %s for %s/%s".format(myNetData, + underlying.socket.getRemoteSocketAddress, + underlying.socket.getLocalSocketAddress)) + try { + do { + val result = sslEngine.wrap(src, myNetData) + handshakeStatus = result.getHandshakeStatus + result.getStatus match { + case SSLEngineResult.Status.OK => + if (handshakeStatus == HandshakeStatus.NEED_TASK) { + handshakeStatus = runTasks() + if (handshakeStatus == null) 
return 0 + } + case SSLEngineResult.Status.BUFFER_OVERFLOW => + val size = if (src.remaining * 2 > sslEngine.getSession.getApplicationBufferSize) src.remaining * 2 + else sslEngine.getSession.getApplicationBufferSize + myNetData = expand(myNetData, size) + case SSLEngineResult.Status.CLOSED => + shutdown() + throw new IOException("Write error received Status.CLOSED") + case _ => + throw new SSLException("Unexpected state!") + } + } while (src.hasRemaining) + } finally { + myNetData.flip() + trace("wrap: flipped myNetData %s for %s/%s".format(myNetData, + underlying.socket.getRemoteSocketAddress, + underlying.socket.getLocalSocketAddress)) + } + written + } + + private[this] def writeRaw(out: ByteBuffer): Long = { + def writeTwo(i: ByteBuffer): ByteBuffer = { + val o = ByteBuffer.allocate(2) + var rem = i.limit - i.position + if (rem > o.capacity) rem = o.capacity + var c = 0 + while (c < rem) { + o.put(i.get) + c += 1 + } + o.flip() + o + } + try { + if (out.hasRemaining) { + underlying.write(if (SSLSocketChannel.simulateSlowNetwork) writeTwo(out) else out) + } else 0 + } catch { + case x: IOException => + sslEngine.closeOutbound() + shutdown = true + throw x + } + } + + private[this] def runTasks(ops: Int = SelectionKey.OP_READ): HandshakeStatus = { + val reInitialize = initialized match { + case 0 => + initialized = ops + info("runTasks running renegotiation for %s/%s".format(underlying.socket.getRemoteSocketAddress, + underlying.socket.getLocalSocketAddress)) + true + case _ => false + } + var runnable: Runnable = sslEngine.getDelegatedTask + if (!blocking && selectionKey != null) { + debug("runTasks asynchronously in ssl handshake for %s/%s".format(underlying.socket.getRemoteSocketAddress, + underlying.socket.getLocalSocketAddress)) + if (runnable != null) { + executor.execute(new SSLTasker(runnable)) + } + null + } else { + debug("runTasks synchronously in ssl handshake for %s/%s".format(underlying.socket.getRemoteSocketAddress, + 
underlying.socket.getLocalSocketAddress)) + while (runnable != null) { + runnable.run() + runnable = sslEngine.getDelegatedTask + } + if (reInitialize) { + handshakeInBlockMode(ops) + } + sslEngine.getHandshakeStatus + } + } + + private[this] def expand(src: ByteBuffer, ensureSize: Int): ByteBuffer = { + if (src.remaining < ensureSize) { + val newBuffer = ByteBuffer.allocate(src.capacity + ensureSize) + if (src.position > 0) { + src.flip() + newBuffer.put(src) + } + newBuffer + } else { + src + } + } + + private[this] def readFromPeerData(dest: ByteBuffer): Int = { + peerAppData.flip() + try { + var remaining = peerAppData.remaining + if (remaining > 0) { + if (remaining > dest.remaining) { + remaining = dest.remaining + } + var i = 0 + while (i < remaining) { + dest.put(peerAppData.get) + i = i + 1 + } + } + remaining + } finally { + peerAppData.compact() + } + } + + def bind(local: SocketAddress): SocketChannel = underlying.bind(local) + + def shutdownInput(): SocketChannel = shutdownInput() + + def setOption[T](name: SocketOption[T], value: T): SocketChannel = underlying.setOption(name, value) + + def getRemoteAddress: SocketAddress = underlying.getRemoteAddress + + def shutdownOutput(): SocketChannel = underlying.shutdownOutput() + + def getLocalAddress: SocketAddress = underlying.getLocalAddress + + def getOption[T](name: SocketOption[T]): T = underlying.getOption(name) + + def supportedOptions(): util.Set[SocketOption[_]] = underlying.supportedOptions +} diff --git a/core/src/main/scala/kafka/network/security/SecureAuth.scala b/core/src/main/scala/kafka/network/security/SecureAuth.scala new file mode 100644 index 0000000000000000000000000000000000000000..908dc4fc1337b647769f37cde66fd28493d29f82 --- /dev/null +++ b/core/src/main/scala/kafka/network/security/SecureAuth.scala @@ -0,0 +1,52 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.network.security + +import javax.net.ssl.SSLContext +import kafka.utils.Logging +import java.util.concurrent.atomic.AtomicBoolean + +object SecureAuth extends Logging { + private val initialized = new AtomicBoolean(false) + private var authContext: SSLContext = null + + def isInitialized = initialized.get + + def sslContext = { + if (!initialized.get) { + throw new IllegalStateException("Secure authentication is not initialized.") + } + authContext + } + + def initialize(config: AuthConfig) { + if (initialized.get) { + warn("Attempt to reinitialize auth context") + return + } + + info("Initializing secure authentication") + + val initializer = KeyStores.getKeyStore(config.keystoreType) + authContext = initializer.initialize(config) + + initialized.set(true) + + info("Secure authentication initialization has been successfully completed") + } +} diff --git a/core/src/main/scala/kafka/network/security/store/JKSInitializer.scala b/core/src/main/scala/kafka/network/security/store/JKSInitializer.scala new file mode 100644 index 0000000000000000000000000000000000000000..5d02ba0460a7e914de35c2799f4cfbe16842679c --- /dev/null +++ b/core/src/main/scala/kafka/network/security/store/JKSInitializer.scala @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation 
(ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.network.security.store + +import java.io.FileInputStream +import javax.net.ssl._ +import kafka.network.security.{StoreInitializer, AuthConfig} + +object JKSInitializer extends StoreInitializer { + val name = "jks" + + def initialize(config: AuthConfig) = { + val tms = config.truststorePwd match { + case pw: String => + val ts = java.security.KeyStore.getInstance("JKS") + val fis: FileInputStream = new FileInputStream(config.truststore) + ts.load(fis, pw.toCharArray) + fis.close() + + val tmf = TrustManagerFactory.getInstance("SunX509") + tmf.init(ts) + tmf.getTrustManagers + case _ => null + } + val kms = config.keystorePwd match { + case pw: String => + val ks = java.security.KeyStore.getInstance("JKS") + val fis: FileInputStream = new FileInputStream(config.keystore) + ks.load(fis, pw.toCharArray) + fis.close() + + val kmf = KeyManagerFactory.getInstance("SunX509") + kmf.init(ks, if (config.keyPwd != null) config.keyPwd.toCharArray else pw.toCharArray) + kmf.getKeyManagers + case _ => null + } + + initContext(tms, kms) + } + + private def initContext(tms: Array[TrustManager], kms: Array[KeyManager]): SSLContext = { + val authContext = SSLContext.getInstance("TLS") + authContext.init(kms, tms, null) + 
authContext + } +} diff --git a/core/src/main/scala/kafka/producer/ProducerConfig.scala b/core/src/main/scala/kafka/producer/ProducerConfig.scala index 3cdf23dce3407f1770b9c6543e3a8ae8ab3ff255..b826a7724f5e24dd3b4b25d3b77660a7ffe015bd 100644 --- a/core/src/main/scala/kafka/producer/ProducerConfig.scala +++ b/core/src/main/scala/kafka/producer/ProducerConfig.scala @@ -22,6 +22,7 @@ import java.util.Properties import kafka.utils.{Utils, VerifiableProperties} import kafka.message.{CompressionCodec, NoCompressionCodec} import kafka.common.{InvalidConfigException, Config} +import kafka.network.security.AuthConfig object ProducerConfig extends Config { def validate(config: ProducerConfig) { @@ -113,5 +114,14 @@ class ProducerConfig private (val props: VerifiableProperties) */ val topicMetadataRefreshIntervalMs = props.getInt("topic.metadata.refresh.interval.ms", 600000) + /** determines whether use SSL or not */ + val secure = props.getBoolean("secure", false) + + /** security config */ + val securityConfig = if (secure) { + info("Secure sockets for data transfer is enabled"); + new AuthConfig(props.getString("security.config.file", null)) + } else null + validate(this) } diff --git a/core/src/main/scala/kafka/producer/ProducerPool.scala b/core/src/main/scala/kafka/producer/ProducerPool.scala index 43df70bb461dd3e385e6b20396adef3c4016a3fc..8996f6d693425bdef77105af3f5928685cd9279e 100644 --- a/core/src/main/scala/kafka/producer/ProducerPool.scala +++ b/core/src/main/scala/kafka/producer/ProducerPool.scala @@ -34,6 +34,7 @@ object ProducerPool { val props = new Properties() props.put("host", broker.host) props.put("port", broker.port.toString) + props.put("secure", broker.secure.toString) props.putAll(config.props.props) new SyncProducer(new SyncProducerConfig(props)) } diff --git a/core/src/main/scala/kafka/producer/SyncProducer.scala b/core/src/main/scala/kafka/producer/SyncProducer.scala index 
489f0077512d9a69be81649c490274964290fa40..344ee9da7cb989f173507aad3cd636319f1553d9 100644 --- a/core/src/main/scala/kafka/producer/SyncProducer.scala +++ b/core/src/main/scala/kafka/producer/SyncProducer.scala @@ -21,6 +21,7 @@ import kafka.api._ import kafka.network.{BlockingChannel, BoundedByteBufferSend, Receive} import kafka.utils._ import java.util.Random +import kafka.network.security.SecureAuth object SyncProducer { val RequestKey: Short = 0 @@ -35,7 +36,11 @@ class SyncProducer(val config: SyncProducerConfig) extends Logging { private val lock = new Object() @volatile private var shutdown: Boolean = false - private val blockingChannel = new BlockingChannel(config.host, config.port, BlockingChannel.UseDefaultBufferSize, + + // If secure setup SSLContext + if (config.secure) SecureAuth.initialize(config.securityConfig) + + private val blockingChannel = new BlockingChannel(config.host, config.port, config.secure, BlockingChannel.UseDefaultBufferSize, config.sendBufferBytes, config.requestTimeoutMs) val brokerInfo = "host_%s-port_%s".format(config.host, config.port) val producerRequestStats = ProducerRequestStatsRegistry.getProducerRequestStats(config.clientId) @@ -126,8 +131,10 @@ class SyncProducer(val config: SyncProducerConfig) extends Logging { */ private def disconnect() { try { - info("Disconnecting from " + config.host + ":" + config.port) - blockingChannel.disconnect() + if(blockingChannel.isConnected) { + info("Disconnecting from " + config.host + ":" + config.port + ":" + config.secure) + blockingChannel.disconnect() + } } catch { case e: Exception => error("Error on disconnect: ", e) } @@ -137,11 +144,11 @@ class SyncProducer(val config: SyncProducerConfig) extends Logging { if (!blockingChannel.isConnected && !shutdown) { try { blockingChannel.connect() - info("Connected to " + config.host + ":" + config.port + " for producing") + info("Connected to " + config.host + ":" + config.port + ":" + config.secure + " for producing") } catch { case e: 
Exception => { disconnect() - error("Producer connection to " + config.host + ":" + config.port + " unsuccessful", e) + error("Producer connection to " + config.host + ":" + config.port + ":" + config.secure + " unsuccessful", e) throw e } } diff --git a/core/src/main/scala/kafka/producer/SyncProducerConfig.scala b/core/src/main/scala/kafka/producer/SyncProducerConfig.scala index 69b2d0c11bb1412ce76d566f285333c806be301a..5ff2cfb661a03c89dff011d45cd7c54524a12a49 100644 --- a/core/src/main/scala/kafka/producer/SyncProducerConfig.scala +++ b/core/src/main/scala/kafka/producer/SyncProducerConfig.scala @@ -18,9 +18,10 @@ package kafka.producer import java.util.Properties -import kafka.utils.VerifiableProperties +import kafka.utils._ +import kafka.network.security.AuthConfig -class SyncProducerConfig private (val props: VerifiableProperties) extends SyncProducerConfigShared { +class SyncProducerConfig private (val props: VerifiableProperties) extends SyncProducerConfigShared with Logging { def this(originalProps: Properties) { this(new VerifiableProperties(originalProps)) // no need to verify the property since SyncProducerConfig is supposed to be used internally @@ -31,6 +32,15 @@ class SyncProducerConfig private (val props: VerifiableProperties) extends SyncP /** the port on which the broker is running */ val port = props.getInt("port") + + /** determines whether use SSL or not */ + val secure = props.getBoolean("secure") + + /** security config */ + val securityConfig = if (secure) { + info("Secure sockets for data transfer is enabled"); + new AuthConfig(props.getString("security.config.file", null)) + } else null } trait SyncProducerConfigShared { diff --git a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala index 2e9532e820b5b5c63dfd55f5454b32866d084a37..f9355c1294d4e5def9628a61d57d038a7ffa075b 100644 --- a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala +++ 
b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala @@ -45,7 +45,7 @@ abstract class AbstractFetcherThread(name: String, clientId: String, sourceBroke private val partitionMap = new mutable.HashMap[TopicAndPartition, Long] // a (topic, partition) -> offset map private val partitionMapLock = new ReentrantLock private val partitionMapCond = partitionMapLock.newCondition() - val simpleConsumer = new SimpleConsumer(sourceBroker.host, sourceBroker.port, socketTimeout, socketBufferSize, clientId) + val simpleConsumer = new SimpleConsumer(sourceBroker.host, sourceBroker.port, socketTimeout, socketBufferSize, clientId, sourceBroker.secure) private val brokerInfo = "host_%s-port_%s".format(sourceBroker.host, sourceBroker.port) private val metricId = new ClientIdAndBroker(clientId, brokerInfo) val fetcherStats = new FetcherStats(metricId) @@ -54,7 +54,8 @@ abstract class AbstractFetcherThread(name: String, clientId: String, sourceBroke clientId(clientId). replicaId(fetcherBrokerId). maxWait(maxWait). - minBytes(minBytes) + minBytes(minBytes). 
+ secure(sourceBroker.secure) /* callbacks to be defined in subclass */ diff --git a/core/src/main/scala/kafka/server/KafkaConfig.scala b/core/src/main/scala/kafka/server/KafkaConfig.scala index dce48db175d6ea379f848a7768de0b1c8e4b929f..7de847e03c46e9b4fb39d7504ec87cb5e290cbe5 100644 --- a/core/src/main/scala/kafka/server/KafkaConfig.scala +++ b/core/src/main/scala/kafka/server/KafkaConfig.scala @@ -21,6 +21,7 @@ import java.util.Properties import kafka.message.{MessageSet, Message} import kafka.consumer.ConsumerConfig import kafka.utils.{VerifiableProperties, ZKConfig, Utils} +import kafka.network.security.AuthConfig /** * Configuration settings for the kafka server @@ -83,6 +84,15 @@ class KafkaConfig private (val props: VerifiableProperties) extends ZKConfig(pro /* the port to listen and accept connections on */ val port: Int = props.getInt("port", 6667) + /* is this running SSL */ + val secure: Boolean = props.getBoolean("secure", false) + + /* security config */ + val securityConfig = if (secure) + new AuthConfig(props.getString("security.config.file")) + else + null + /* hostname of broker. If this is set, it will only bind to this address. 
If this is not set, * it will bind to all interfaces */ val hostName: String = props.getString("host.name", null) diff --git a/core/src/main/scala/kafka/server/KafkaHealthcheck.scala b/core/src/main/scala/kafka/server/KafkaHealthcheck.scala index 4acdd70fe9c1ee78d6510741006c2ece65450671..cf6557b84a8ea4bbac9dc1cf6e008ac000818779 100644 --- a/core/src/main/scala/kafka/server/KafkaHealthcheck.scala +++ b/core/src/main/scala/kafka/server/KafkaHealthcheck.scala @@ -34,6 +34,7 @@ import java.net.InetAddress class KafkaHealthcheck(private val brokerId: Int, private val advertisedHost: String, private val advertisedPort: Int, + private val secure: Boolean, private val zkSessionTimeoutMs: Int, private val zkClient: ZkClient) extends Logging { @@ -60,7 +61,7 @@ class KafkaHealthcheck(private val brokerId: Int, else advertisedHost val jmxPort = System.getProperty("com.sun.management.jmxremote.port", "-1").toInt - ZkUtils.registerBrokerInZk(zkClient, brokerId, advertisedHostName, advertisedPort, zkSessionTimeoutMs, jmxPort) + ZkUtils.registerBrokerInZk(zkClient, brokerId, advertisedHostName, advertisedPort, zkSessionTimeoutMs, jmxPort, secure) } /** diff --git a/core/src/main/scala/kafka/server/KafkaServer.scala b/core/src/main/scala/kafka/server/KafkaServer.scala index 28711182aaa70eaa623de858bc063cb2613b2a4d..82362b4f8ae798ec05730eb81fa5715faaf5cd42 100644 --- a/core/src/main/scala/kafka/server/KafkaServer.scala +++ b/core/src/main/scala/kafka/server/KafkaServer.scala @@ -87,6 +87,8 @@ class KafkaServer(val config: KafkaConfig, time: Time = SystemTime) extends Logg socketServer = new SocketServer(config.brokerId, config.hostName, config.port, + config.secure, + config.securityConfig, config.numNetworkThreads, config.queuedMaxRequests, config.socketSendBufferBytes, @@ -101,31 +103,31 @@ class KafkaServer(val config: KafkaConfig, time: Time = SystemTime) extends Logg offsetManager = createOffsetManager() kafkaController = new KafkaController(config, zkClient, brokerState) - + 
/* start processing requests */ apis = new KafkaApis(socketServer.requestChannel, replicaManager, offsetManager, zkClient, config.brokerId, config, kafkaController) requestHandlerPool = new KafkaRequestHandlerPool(config.brokerId, socketServer.requestChannel, apis, config.numIoThreads) brokerState.newState(RunningAsBroker) - + Mx4jLoader.maybeLoad() replicaManager.startup() kafkaController.startup() - + topicConfigManager = new TopicConfigManager(zkClient, logManager) topicConfigManager.startup() - + /* tell everyone we are alive */ - kafkaHealthcheck = new KafkaHealthcheck(config.brokerId, config.advertisedHostName, config.advertisedPort, config.zkSessionTimeoutMs, zkClient) + kafkaHealthcheck = new KafkaHealthcheck(config.brokerId, config.advertisedHostName, config.advertisedPort, config.secure, config.zkSessionTimeoutMs, zkClient) kafkaHealthcheck.startup() - + registerStats() startupComplete.set(true) info("started") } - + private def initZk(): ZkClient = { info("Connecting to zookeeper on " + config.zkConnect) val zkClient = new ZkClient(config.zkConnect, config.zkSessionTimeoutMs, config.zkConnectionTimeoutMs, ZKStringSerializer) @@ -173,7 +175,7 @@ class KafkaServer(val config: KafkaConfig, time: Time = SystemTime) extends Logg if (channel != null) { channel.disconnect() } - channel = new BlockingChannel(broker.host, broker.port, + channel = new BlockingChannel(broker.host, broker.port, broker.secure, BlockingChannel.UseDefaultBufferSize, BlockingChannel.UseDefaultBufferSize, config.controllerSocketTimeoutMs) @@ -273,7 +275,7 @@ class KafkaServer(val config: KafkaConfig, time: Time = SystemTime) extends Logg def awaitShutdown(): Unit = shutdownLatch.await() def getLogManager(): LogManager = logManager - + private def createLogManager(zkClient: ZkClient, brokerState: BrokerState): LogManager = { val defaultLogConfig = LogConfig(segmentSize = config.logSegmentBytes, segmentMs = config.logRollTimeMillis, diff --git 
a/core/src/main/scala/kafka/tools/ConsoleConsumer.scala b/core/src/main/scala/kafka/tools/ConsoleConsumer.scala index 323fc8566d974acc4e5c7d7c2a065794f3b5df4a..5683852da39186f4ec24cdf577c38d35efdf657c 100644 --- a/core/src/main/scala/kafka/tools/ConsoleConsumer.scala +++ b/core/src/main/scala/kafka/tools/ConsoleConsumer.scala @@ -83,6 +83,10 @@ object ConsoleConsumer extends Logging { .withRequiredArg .describedAs("metrics dictory") .ofType(classOf[java.lang.String]) + val securityConfigFileOpt = parser.accepts("security.config.file", "Security config file to use for SSL.") + .withRequiredArg + .describedAs("property file") + .ofType(classOf[java.lang.String]) if(args.length == 0) CommandLineUtils.printUsageAndDie(parser, "The console consumer is a tool that reads data from Kafka and outputs it to standard output.") @@ -137,6 +141,9 @@ object ConsoleConsumer extends Logging { +". Please use --delete-consumer-offsets to delete previous offsets metadata") System.exit(1) } + if (options.has(securityConfigFileOpt)) { + consumerProps.put("security.config.file", options.valueOf(securityConfigFileOpt)) + } if(options.has(deleteConsumerOffsetsOpt)) ZkUtils.maybeDeletePath(options.valueOf(zkConnectOpt), "/consumers/" + consumerProps.getProperty("group.id")) diff --git a/core/src/main/scala/kafka/tools/ConsoleProducer.scala b/core/src/main/scala/kafka/tools/ConsoleProducer.scala index da4dad405c8d8f26a64cda78a292e1f5bfbdcc22..21fd10207c5891003e3a111885e5de741ee08d71 100644 --- a/core/src/main/scala/kafka/tools/ConsoleProducer.scala +++ b/core/src/main/scala/kafka/tools/ConsoleProducer.scala @@ -27,6 +27,7 @@ import java.util.Properties import java.io._ import joptsimple._ +import kafka.network.security.AuthConfig object ConsoleProducer { @@ -59,6 +60,8 @@ object ConsoleProducer { props.put(ProducerConfig.BUFFER_MEMORY_CONFIG, config.maxMemoryBytes.toString) props.put(ProducerConfig.BATCH_SIZE_CONFIG, config.maxPartitionMemoryBytes.toString) 
props.put(ProducerConfig.CLIENT_ID_CONFIG, "console-producer") + props.put(ProducerConfig.SECURE, config.secure.toString) + if (config.securityConfig != null) props.put(ProducerConfig.SECURITY_CONFIG_FILE, config.securityConfig) new NewShinyProducer(props) } else { @@ -78,6 +81,8 @@ object ConsoleProducer { props.put("send.buffer.bytes", config.socketBuffer.toString) props.put("topic.metadata.refresh.interval.ms", config.metadataExpiryMs.toString) props.put("client.id", "console-producer") + props.put("secure", config.secure.toString) + if (config.securityConfig != null) props.put("security.config.file", config.securityConfig) new OldProducer(props) } @@ -210,6 +215,11 @@ object ConsoleProducer { .describedAs("prop") .ofType(classOf[String]) val useNewProducerOpt = parser.accepts("new-producer", "Use the new producer implementation.") + val secureOpt = parser.accepts("secure", "Whether SSL enabled").withOptionalArg() + val securityConfigFileOpt = parser.accepts("client.security.file", "Security config file to use for SSL.") + .withRequiredArg + .describedAs("property file") + .ofType(classOf[java.lang.String]) val options = parser.parse(args : _*) if(args.length == 0) @@ -245,6 +255,9 @@ object ConsoleProducer { val maxPartitionMemoryBytes = options.valueOf(maxPartitionMemoryBytesOpt) val metadataExpiryMs = options.valueOf(metadataExpiryMsOpt) val metadataFetchTimeoutMs = options.valueOf(metadataFetchTimeoutMsOpt) + + val secure = options.has(secureOpt) + val securityConfig = options.valueOf(securityConfigFileOpt) } trait MessageReader { diff --git a/core/src/main/scala/kafka/tools/ConsumerOffsetChecker.scala b/core/src/main/scala/kafka/tools/ConsumerOffsetChecker.scala index d1e7c434e77859d746b8dc68dd5d5a3740425e79..b238b385245a01ed73fb5eeda626f08fd5551625 100644 --- a/core/src/main/scala/kafka/tools/ConsumerOffsetChecker.scala +++ b/core/src/main/scala/kafka/tools/ConsumerOffsetChecker.scala @@ -37,7 +37,7 @@ object ConsumerOffsetChecker extends Logging { private 
val offsetMap: mutable.Map[TopicAndPartition, Long] = mutable.Map() private var topicPidMap: immutable.Map[String, Seq[Int]] = immutable.Map() - private def getConsumer(zkClient: ZkClient, bid: Int): Option[SimpleConsumer] = { + private def getConsumer(zkClient: ZkClient, bid: Int, securityConfigFile : String): Option[SimpleConsumer] = { try { ZkUtils.readDataMaybeNull(zkClient, ZkUtils.BrokerIdsPath + "/" + bid)._1 match { case Some(brokerInfoString) => @@ -46,7 +46,8 @@ object ConsumerOffsetChecker extends Logging { val brokerInfo = m.asInstanceOf[Map[String, Any]] val host = brokerInfo.get("host").get.asInstanceOf[String] val port = brokerInfo.get("port").get.asInstanceOf[Int] - Some(new SimpleConsumer(host, port, 10000, 100000, "ConsumerOffsetChecker")) + val secure = brokerInfo.get("secure").get.asInstanceOf[Boolean] + Some(new SimpleConsumer(host, port, 10000, 100000, "ConsumerOffsetChecker", secure, securityConfigFile)) case None => throw new BrokerNotAvailableException("Broker id %d does not exist".format(bid)) } @@ -61,14 +62,14 @@ object ConsumerOffsetChecker extends Logging { } private def processPartition(zkClient: ZkClient, - group: String, topic: String, pid: Int) { + group: String, topic: String, pid: Int, securityConfigFile : String) { val topicPartition = TopicAndPartition(topic, pid) val offsetOpt = offsetMap.get(topicPartition) val groupDirs = new ZKGroupTopicDirs(group, topic) val owner = ZkUtils.readDataMaybeNull(zkClient, groupDirs.consumerOwnerDir + "/%s".format(pid))._1 ZkUtils.getLeaderForPartition(zkClient, topic, pid) match { case Some(bid) => - val consumerOpt = consumerMap.getOrElseUpdate(bid, getConsumer(zkClient, bid)) + val consumerOpt = consumerMap.getOrElseUpdate(bid, getConsumer(zkClient, bid, securityConfigFile)) consumerOpt match { case Some(consumer) => val topicAndPartition = TopicAndPartition(topic, pid) @@ -86,11 +87,11 @@ object ConsumerOffsetChecker extends Logging { } } - private def processTopic(zkClient: ZkClient, 
group: String, topic: String) { + private def processTopic(zkClient: ZkClient, group: String, topic: String, securityConfigFile : String) { topicPidMap.get(topic) match { case Some(pids) => pids.sorted.foreach { - pid => processPartition(zkClient, group, topic, pid) + pid => processPartition(zkClient, group, topic, pid, securityConfigFile) } case None => // ignore } @@ -123,17 +124,33 @@ object ConsumerOffsetChecker extends Logging { parser.accepts("broker-info", "Print broker info") parser.accepts("help", "Print this message.") + val securityConfigFileOpt = parser.accepts("security.config.file", "Security config file to use for SSL.") + .withRequiredArg + .describedAs("property file") + .ofType(classOf[java.lang.String]) if(args.length == 0) CommandLineUtils.printUsageAndDie(parser, "Check the offset of your consumers.") val options = parser.parse(args : _*) + var securityConfigFile:String = null; if (options.has("help")) { parser.printHelpOn(System.out) System.exit(0) } + if (options.has(securityConfigFileOpt)){ + securityConfigFile = options.valueOf(securityConfigFileOpt); + } + + for (opt <- List(groupOpt, zkConnectOpt)) + if (!options.has(opt)) { + System.err.println("Missing required argument: %s".format(opt)) + parser.printHelpOn(System.err) + System.exit(1) + } + CommandLineUtils.checkRequiredArgs(parser, options, groupOpt, zkConnectOpt) val zkConnect = options.valueOf(zkConnectOpt) @@ -191,7 +208,7 @@ object ConsumerOffsetChecker extends Logging { println("%-15s %-30s %-3s %-15s %-15s %-15s %s".format("Group", "Topic", "Pid", "Offset", "logSize", "Lag", "Owner")) topicList.sorted.foreach { - topic => processTopic(zkClient, group, topic) + topic => processTopic(zkClient, group, topic, securityConfigFile) } if (options.has("broker-info")) diff --git a/core/src/main/scala/kafka/tools/GetOffsetShell.scala b/core/src/main/scala/kafka/tools/GetOffsetShell.scala index 9c6064e201eebbcd5b276a0dedd02937439edc94..0ec94c91b53f3d7bf75dbf046767cc0b637db66b 100644 --- 
a/core/src/main/scala/kafka/tools/GetOffsetShell.scala +++ b/core/src/main/scala/kafka/tools/GetOffsetShell.scala @@ -57,6 +57,10 @@ object GetOffsetShell { .describedAs("ms") .ofType(classOf[java.lang.Integer]) .defaultsTo(1000) + val securityConfigFileOpt = parser.accepts("security.config.file", "Security config file to use for SSL.") + .withRequiredArg + .describedAs("property file") + .ofType(classOf[java.lang.String]) if(args.length == 0) CommandLineUtils.printUsageAndDie(parser, "An interactive shell for getting consumer offsets.") @@ -72,8 +76,9 @@ object GetOffsetShell { var time = options.valueOf(timeOpt).longValue val nOffsets = options.valueOf(nOffsetsOpt).intValue val maxWaitMs = options.valueOf(maxWaitMsOpt).intValue() + val secure = options.has(securityConfigFileOpt) - val topicsMetadata = ClientUtils.fetchTopicMetadata(Set(topic), metadataTargetBrokers, clientId, maxWaitMs).topicsMetadata + val topicsMetadata = ClientUtils.fetchTopicMetadata(Set(topic), metadataTargetBrokers, clientId, options.valueOf(securityConfigFileOpt), maxWaitMs).topicsMetadata if(topicsMetadata.size != 1 || !topicsMetadata(0).topic.equals(topic)) { System.err.println(("Error: no valid topic metadata for topic: %s, " + " probably the topic does not exist, run ").format(topic) + "kafka-list-topic.sh to verify") @@ -91,7 +96,7 @@ object GetOffsetShell { case Some(metadata) => metadata.leader match { case Some(leader) => - val consumer = new SimpleConsumer(leader.host, leader.port, 10000, 100000, clientId) + val consumer = new SimpleConsumer(leader.host, leader.port, 10000, 100000, clientId, secure, options.valueOf(securityConfigFileOpt)) val topicAndPartition = TopicAndPartition(topic, partitionId) val request = OffsetRequest(Map(topicAndPartition -> PartitionOffsetRequestInfo(time, nOffsets))) val offsets = consumer.getOffsetsBefore(request).partitionErrorAndOffsets(topicAndPartition).offsets diff --git a/core/src/main/scala/kafka/tools/ReplicaVerificationTool.scala 
b/core/src/main/scala/kafka/tools/ReplicaVerificationTool.scala index af4783646803e58714770c21f8c3352370f26854..d35c3c005c7d7474236c8e473c50f26648549773 100644 --- a/core/src/main/scala/kafka/tools/ReplicaVerificationTool.scala +++ b/core/src/main/scala/kafka/tools/ReplicaVerificationTool.scala @@ -92,6 +92,10 @@ object ReplicaVerificationTool extends Logging { .describedAs("ms") .ofType(classOf[java.lang.Long]) .defaultsTo(30 * 1000L) + val securityConfigFileOpt = parser.accepts("security.config.file", "Security config file to use for SSL.") + .withRequiredArg + .describedAs("property file") + .ofType(classOf[java.lang.String]) if(args.length == 0) CommandLineUtils.printUsageAndDie(parser, "Validate that all replicas for a set of topics have the same data.") @@ -114,10 +118,11 @@ object ReplicaVerificationTool extends Logging { val maxWaitMs = options.valueOf(maxWaitMsOpt).intValue val initialOffsetTime = options.valueOf(initialOffsetTimeOpt).longValue val reportInterval = options.valueOf(reportIntervalOpt).longValue + val securityConfigFile = options.valueOf(securityConfigFileOpt) // getting topic metadata info("Getting topic metatdata...") val metadataTargetBrokers = ClientUtils.parseBrokerList(options.valueOf(brokerListOpt)) - val topicsMetadataResponse = ClientUtils.fetchTopicMetadata(Set[String](), metadataTargetBrokers, clientId, maxWaitMs) + val topicsMetadataResponse = ClientUtils.fetchTopicMetadata(Set[String](), metadataTargetBrokers, clientId, securityConfigFile, maxWaitMs) val brokerMap = topicsMetadataResponse.brokers.map(b => (b.id, b)).toMap val filteredTopicMetadata = topicsMetadataResponse.topicsMetadata.filter( topicMetadata => if (topicWhiteListFiler.isTopicAllowed(topicMetadata.topic, excludeInternalTopics = false)) diff --git a/core/src/main/scala/kafka/tools/SimpleConsumerShell.scala b/core/src/main/scala/kafka/tools/SimpleConsumerShell.scala index 36314f412a8281aece2789fd2b74a106b82c57d2..24c60736b7ad20a4a99299fd9973571c7add42d4 100644 --- 
a/core/src/main/scala/kafka/tools/SimpleConsumerShell.scala +++ b/core/src/main/scala/kafka/tools/SimpleConsumerShell.scala @@ -96,6 +96,10 @@ object SimpleConsumerShell extends Logging { if(args.length == 0) CommandLineUtils.printUsageAndDie(parser, "A low-level tool for fetching data directly from a particular replica.") + val securityConfigFileOpt = parser.accepts("security.config.file", "Security config file to use for SSL.") + .withRequiredArg + .describedAs("property file") + .ofType(classOf[java.lang.String]) val options = parser.parse(args : _*) CommandLineUtils.checkRequiredArgs(parser, options, brokerListOpt, topicOpt, partitionIdOpt) @@ -112,6 +116,7 @@ object SimpleConsumerShell extends Logging { val skipMessageOnError = if (options.has(skipMessageOnErrorOpt)) true else false val printOffsets = if(options.has(printOffsetOpt)) true else false val noWaitAtEndOfLog = options.has(noWaitAtEndOfLogOpt) + val securityConfigFile = options.valueOf(securityConfigFileOpt) val messageFormatterClass = Class.forName(options.valueOf(messageFormatterOpt)) val formatterArgs = CommandLineUtils.parseKeyValueArgs(options.valuesOf(messageFormatterArgOpt)) @@ -125,7 +130,7 @@ object SimpleConsumerShell extends Logging { // getting topic metadata info("Getting topic metatdata...") val metadataTargetBrokers = ClientUtils.parseBrokerList(options.valueOf(brokerListOpt)) - val topicsMetadata = ClientUtils.fetchTopicMetadata(Set(topic), metadataTargetBrokers, clientId, maxWaitMs).topicsMetadata + val topicsMetadata = ClientUtils.fetchTopicMetadata(Set(topic), metadataTargetBrokers, clientId, securityConfigFile, maxWaitMs).topicsMetadata if(topicsMetadata.size != 1 || !topicsMetadata(0).topic.equals(topic)) { System.err.println(("Error: no valid topic metadata for topic: %s, " + "what we get from server is only: %s").format(topic, topicsMetadata)) System.exit(1) diff --git a/core/src/main/scala/kafka/utils/ZkUtils.scala b/core/src/main/scala/kafka/utils/ZkUtils.scala index 
a7b1fdcb50d5cf930352d37e39cb4fc9a080cb12..a34ba2c079d88fe9fa2544cbdc5289b11b7fd2e7 100644 --- a/core/src/main/scala/kafka/utils/ZkUtils.scala +++ b/core/src/main/scala/kafka/utils/ZkUtils.scala @@ -158,11 +158,12 @@ object ZkUtils extends Logging { } } - def registerBrokerInZk(zkClient: ZkClient, id: Int, host: String, port: Int, timeout: Int, jmxPort: Int) { + def registerBrokerInZk(zkClient: ZkClient, id: Int, host: String, port: Int, timeout: Int, jmxPort: Int, secure: Boolean) { val brokerIdPath = ZkUtils.BrokerIdsPath + "/" + id val timestamp = SystemTime.milliseconds.toString - val brokerInfo = Json.encode(Map("version" -> 1, "host" -> host, "port" -> port, "jmx_port" -> jmxPort, "timestamp" -> timestamp)) - val expectedBroker = new Broker(id, host, port) + val secureValue = if (secure) 1 else 0 + val brokerInfo = Json.encode(Map("version" -> 1, "host" -> host, "port" -> port, "jmx_port" -> jmxPort, "timestamp" -> timestamp, "secure" -> secureValue)) + val expectedBroker = new Broker(id, host, port, secure) try { createEphemeralPathExpectConflictHandleZKBug(zkClient, brokerIdPath, brokerInfo, expectedBroker, diff --git a/core/src/test/scala/unit/kafka/admin/AddPartitionsTest.scala b/core/src/test/scala/unit/kafka/admin/AddPartitionsTest.scala index 1bf2667f47853585bc33ffb3e28256ec5f24ae84..aad0ec1f37b479c0ec76cc7abedd3adc7799e393 100644 --- a/core/src/test/scala/unit/kafka/admin/AddPartitionsTest.scala +++ b/core/src/test/scala/unit/kafka/admin/AddPartitionsTest.scala @@ -110,7 +110,7 @@ class AddPartitionsTest extends JUnit3Suite with ZooKeeperTestHarness { TestUtils.waitUntilMetadataIsPropagated(servers, topic1, 1) TestUtils.waitUntilMetadataIsPropagated(servers, topic1, 2) val metadata = ClientUtils.fetchTopicMetadata(Set(topic1), brokers, "AddPartitionsTest-testIncrementPartitions", - 2000,0).topicsMetadata + null, 2000,0).topicsMetadata val metaDataForTopic1 = metadata.filter(p => p.topic.equals(topic1)) val partitionDataForTopic1 = 
metaDataForTopic1.head.partitionsMetadata assertEquals(partitionDataForTopic1.size, 3) @@ -135,7 +135,7 @@ class AddPartitionsTest extends JUnit3Suite with ZooKeeperTestHarness { TestUtils.waitUntilMetadataIsPropagated(servers, topic2, 1) TestUtils.waitUntilMetadataIsPropagated(servers, topic2, 2) val metadata = ClientUtils.fetchTopicMetadata(Set(topic2), brokers, "AddPartitionsTest-testManualAssignmentOfReplicas", - 2000,0).topicsMetadata + null, 2000,0).topicsMetadata val metaDataForTopic2 = metadata.filter(p => p.topic.equals(topic2)) val partitionDataForTopic2 = metaDataForTopic2.head.partitionsMetadata assertEquals(partitionDataForTopic2.size, 3) @@ -159,7 +159,7 @@ class AddPartitionsTest extends JUnit3Suite with ZooKeeperTestHarness { TestUtils.waitUntilMetadataIsPropagated(servers, topic3, 6) val metadata = ClientUtils.fetchTopicMetadata(Set(topic3), brokers, "AddPartitionsTest-testReplicaPlacement", - 2000,0).topicsMetadata + null, 2000,0).topicsMetadata val metaDataForTopic3 = metadata.filter(p => p.topic.equals(topic3)).head val partition1DataForTopic3 = metaDataForTopic3.partitionsMetadata(1) diff --git a/core/src/test/scala/unit/kafka/integration/TopicMetadataTest.scala b/core/src/test/scala/unit/kafka/integration/TopicMetadataTest.scala index 35dc071b1056e775326981573c9618d8046e601d..acdb03ffed575a1c3d164e662e0893ee4579dbcd 100644 --- a/core/src/test/scala/unit/kafka/integration/TopicMetadataTest.scala +++ b/core/src/test/scala/unit/kafka/integration/TopicMetadataTest.scala @@ -68,7 +68,7 @@ class TopicMetadataTest extends JUnit3Suite with ZooKeeperTestHarness { createTopic(zkClient, topic, numPartitions = 1, replicationFactor = 1, servers = Seq(server1)) var topicsMetadata = ClientUtils.fetchTopicMetadata(Set(topic),brokers,"TopicMetadataTest-testBasicTopicMetadata", - 2000,0).topicsMetadata + null, 2000,0).topicsMetadata assertEquals(ErrorMapping.NoError, topicsMetadata.head.errorCode) assertEquals(ErrorMapping.NoError, 
topicsMetadata.head.partitionsMetadata.head.errorCode) assertEquals("Expecting metadata only for 1 topic", 1, topicsMetadata.size) @@ -88,7 +88,7 @@ class TopicMetadataTest extends JUnit3Suite with ZooKeeperTestHarness { // issue metadata request with empty list of topics var topicsMetadata = ClientUtils.fetchTopicMetadata(Set.empty, brokers, "TopicMetadataTest-testGetAllTopicMetadata", - 2000, 0).topicsMetadata + null, 2000, 0).topicsMetadata assertEquals(ErrorMapping.NoError, topicsMetadata.head.errorCode) assertEquals(2, topicsMetadata.size) assertEquals(ErrorMapping.NoError, topicsMetadata.head.partitionsMetadata.head.errorCode) @@ -107,7 +107,7 @@ class TopicMetadataTest extends JUnit3Suite with ZooKeeperTestHarness { // auto create topic val topic = "testAutoCreateTopic" var topicsMetadata = ClientUtils.fetchTopicMetadata(Set(topic),brokers,"TopicMetadataTest-testAutoCreateTopic", - 2000,0).topicsMetadata + null, 2000,0).topicsMetadata assertEquals(ErrorMapping.LeaderNotAvailableCode, topicsMetadata.head.errorCode) assertEquals("Expecting metadata only for 1 topic", 1, topicsMetadata.size) assertEquals("Expecting metadata for the test topic", topic, topicsMetadata.head.topic) @@ -119,7 +119,7 @@ class TopicMetadataTest extends JUnit3Suite with ZooKeeperTestHarness { // retry the metadata for the auto created topic topicsMetadata = ClientUtils.fetchTopicMetadata(Set(topic),brokers,"TopicMetadataTest-testBasicTopicMetadata", - 2000,0).topicsMetadata + null, 2000,0).topicsMetadata assertEquals(ErrorMapping.NoError, topicsMetadata.head.errorCode) assertEquals(ErrorMapping.NoError, topicsMetadata.head.partitionsMetadata.head.errorCode) var partitionMetadata = topicsMetadata.head.partitionsMetadata diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala index 3b83a86a773147c9b29b4d271aee480efec748ad..ea649535e034744bb1119735c30f3c843002af6e 100644 --- 
a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -36,10 +36,12 @@ class SocketServerTest extends JUnitSuite { val server: SocketServer = new SocketServer(0, host = null, port = kafka.utils.TestUtils.choosePort, + secure = false, + securityConfig = null, numProcessorThreads = 1, maxQueuedRequests = 50, sendBufferSize = 300000, - recvBufferSize = 300000, + receiveBufferSize = 300000, maxRequestSize = 50, maxConnectionsPerIp = 5) server.startup() diff --git a/core/src/test/scala/unit/kafka/utils/TestUtils.scala b/core/src/test/scala/unit/kafka/utils/TestUtils.scala index c4e13c5240c8303853d08cc3b40088f8c7dae460..df7a9af59e0f581bde543d07a40a9f10ff44794f 100644 --- a/core/src/test/scala/unit/kafka/utils/TestUtils.scala +++ b/core/src/test/scala/unit/kafka/utils/TestUtils.scala @@ -437,7 +437,7 @@ object TestUtils extends Logging { def createBrokersInZk(zkClient: ZkClient, ids: Seq[Int]): Seq[Broker] = { val brokers = ids.map(id => new Broker(id, "localhost", 6667)) - brokers.foreach(b => ZkUtils.registerBrokerInZk(zkClient, b.id, b.host, b.port, 6000, jmxPort = -1)) + brokers.foreach(b => ZkUtils.registerBrokerInZk(zkClient, b.id, b.host, b.port, 6000, jmxPort = -1, false)) brokers } diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index a7634b071cb255e91a4572934e55b8cd8877b3e4..d5c591c9c532da774b062aa6d7ab16405ff8700a 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradlew b/gradlew index c312b910b570f46f0435302cda44295d8903c573..91a7e269e19dfc62e27137a0b57ef3e430cee4fd 100755 --- a/gradlew +++ b/gradlew @@ -7,7 +7,7 @@ ############################################################################## # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 
-DEFAULT_JVM_OPTS="-Xmx1024m -Xms256m -XX:MaxPermSize=512m" +DEFAULT_JVM_OPTS="" APP_NAME="Gradle" APP_BASE_NAME=`basename "$0"` diff --git a/gradlew.bat b/gradlew.bat index 84974e20d1bef7ff9ef43933c913252189315458..8a0b282aa6885fb573c106b3551f7275c5f17e8e 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -9,7 +9,7 @@ if "%OS%"=="Windows_NT" setlocal @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -set DEFAULT_JVM_OPTS=-Xmx1024m -Xms256m -XX:MaxPermSize=512m +set DEFAULT_JVM_OPTS= set DIRNAME=%~dp0 if "%DIRNAME%" == "" set DIRNAME=.