From 4a86fcc96b171f35793690f4bb2752065db4cc53 Mon Sep 17 00:00:00 2001
From: jholoman
Date: Thu, 8 Jan 2015 11:25:36 -0500
Subject: [PATCH] KAFKA-1810

---
 core/src/main/scala/kafka/network/IPFilter.scala              | 123 +++++++++++
 core/src/main/scala/kafka/network/SocketServer.scala          |  81 +++----
 core/src/main/scala/kafka/server/KafkaConfig.scala            |   9 +
 core/src/main/scala/kafka/server/KafkaServer.scala            |   5 +-
 core/src/main/scala/kafka/utils/VerifiableProperties.scala    |   7 +
 core/src/test/scala/unit/kafka/network/IPFilterTest.scala     | 234 +++++++++++++++++++++
 core/src/test/scala/unit/kafka/network/SocketServerTest.scala | 104 ++++++---
 core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala   |  13 +-
 8 files changed, 512 insertions(+), 64 deletions(-)
 create mode 100644 core/src/main/scala/kafka/network/IPFilter.scala
 create mode 100644 core/src/test/scala/unit/kafka/network/IPFilterTest.scala

diff --git a/core/src/main/scala/kafka/network/IPFilter.scala b/core/src/main/scala/kafka/network/IPFilter.scala
new file mode 100644
index 0000000..4369ffa
--- /dev/null
+++ b/core/src/main/scala/kafka/network/IPFilter.scala
@@ -0,0 +1,123 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package kafka.network
+
+import java.math.BigInteger
+import java.net.{InetAddress, UnknownHostException}
+import kafka.common.KafkaException
+import kafka.metrics.KafkaMetricsGroup
+import kafka.network.IPFilter._
+import kafka.utils._
+import scala.util.matching.Regex
+
+object IPFilter {
+  // Most of this can move to ConfigDef if/when that happens
+  val validRules = "(allow|deny|none)"
+  val rgx = new Regex(validRules)
+  val AllowRule: (String, String) = ("allow", "Whitelist")
+  val DenyRule: (String, String) = ("deny", "Blacklist")
+  val NoRule = "none"
+
+  def checkRule(ruleType: String): String = {
+    rgx.findFirstIn(ruleType) match {
+      case Some(_) => ruleType
+      case None => throw new IPFilterConfigException("Invalid rule type specified: " + ruleType)
+    }
+  }
+
+  def checkRangeAndRule(range: List[CIDRRange], ruleType: String) = {
+    if ((range.nonEmpty && ruleType == NoRule) || (range.isEmpty && ruleType != NoRule)) {
+      throw new IPFilterConfigException("Error processing IPFilter list: the rule type and the rules list must be set together")
+    }
+  }
+}
+
+class IPFilter(ipFilterList: List[String], ipFilterRuleType: String) extends Logging with KafkaMetricsGroup {
+  import IPFilter._
+
+  val ruleType = checkRule(ipFilterRuleType)
+
+  val filterList: List[CIDRRange] = {
+    try {
+      ipFilterList.map(entry => new CIDRRange(entry))
+    } catch {
+      case e: UnknownHostException => throw new IPFilterConfigException("Error processing IPFilter list: " + e.getMessage)
+    }
+  }
+
+  checkRangeAndRule(filterList, ruleType)
+
+  def check(inetAddress: InetAddress): Boolean = {
+    ruleType match {
+      case AllowRule._1 =>
+        if (filterList.exists(_.contains(inetAddress)))
+          true
+        else throw new IPFilterException(inetAddress, AllowRule._2)
+      case DenyRule._1 =>
+        if (filterList.exists(_.contains(inetAddress)))
+          throw new IPFilterException(inetAddress, DenyRule._2)
+        else true
+      case _ => true
+    }
+  }
+}
+
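+/*
+ * Worked example (reviewer illustration, not behavior this patch changes):
+ * for "192.168.2.0/25" the mask has the top 25 of 32 bits set;
+ * low = 192.168.2.0 & mask = 192.168.2.0, and high = low + ~mask
+ * = low + 127 = 192.168.2.127, so contains() accepts exactly
+ * 192.168.2.0 through 192.168.2.127.
+ */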
bitwise "add" the flipped mask for high address */ + val low: BigInt = new BigInt(new BigInteger(1, inetAddress.getAddress)).&(mask) + val high: BigInt = low.+(~mask) + + /* match for IPV4 or IPV6 (4 or 16) + * fill a BigInteger with all ones (-1 = 11111111) for each octet , flip the bits with not() + * bit-shift right by the length of the prefix + */ + def getMask(prefix: Int, prefixLen: Int): BigInt = { + if ((prefix < 0 || prefix > 128) || (prefix > 32 && prefixLen == 4)) { + throw new UnknownHostException("Not a valid prefix length " + prefix) + } + prefixLen match { + case x if x == 4 => new BigInteger(1, Array.fill[Byte](4)(-1)).not().shiftRight(prefix) + case x if x == 16 => new BigInteger(1, Array.fill[Byte](16)(-1)).not().shiftRight(prefix) + } + } + + def contains(inetAddress: InetAddress): Boolean = { + val ip: BigInt = new BigInteger(1, inetAddress.getAddress()) + ip >= low && ip <= high + } +} + +class IPFilterException(val ip: InetAddress, val ruleType: String) extends KafkaException("Rejected connection from %s (%s)".format(ip, ruleType)) + +class IPFilterConfigException(message: String) extends KafkaException(message) { + def this() = this(null) +} + diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala index 39b1651..c79265e 100644 --- a/core/src/main/scala/kafka/network/SocketServer.scala +++ b/core/src/main/scala/kafka/network/SocketServer.scala @@ -23,7 +23,6 @@ import java.util.concurrent.atomic._ import java.net._ import java.io._ import java.nio.channels._ - import scala.collection._ import kafka.common.KafkaException @@ -47,7 +46,9 @@ class SocketServer(val brokerId: Int, val maxRequestSize: Int = Int.MaxValue, val maxConnectionsPerIp: Int = Int.MaxValue, val connectionsMaxIdleMs: Long, - val maxConnectionsPerIpOverrides: Map[String, Int] ) extends Logging with KafkaMetricsGroup { + val ipFilterRuleType: String, + val ipFilterList: List[String], + val maxConnectionsPerIpOverrides: Map[String, Int] = Map[String, Int]()) extends Logging with KafkaMetricsGroup { this.logIdent = "[Socket Server on Broker " + brokerId + "], " private val time = SystemTime private val processors = new Array[Processor](numProcessorThreads) @@ -62,16 +63,18 @@ class SocketServer(val brokerId: Int, */ def startup() { val quotas = new ConnectionQuotas(maxConnectionsPerIp, maxConnectionsPerIpOverrides) + val ipFilters = new IPFilter(ipFilterList, ipFilterRuleType) + for(i <- 0 until numProcessorThreads) { - processors(i) = new Processor(i, - time, - maxRequestSize, - aggregateIdleMeter, - newMeter("IdlePercent", "percent", TimeUnit.NANOSECONDS, Map("networkProcessor" -> i.toString)), - numProcessorThreads, - requestChannel, - quotas, - connectionsMaxIdleMs) + processors(i) = new Processor(i, + time, + maxRequestSize, + aggregateIdleMeter, + newMeter("IdlePercent", "percent", TimeUnit.NANOSECONDS, Map("networkProcessor" -> i.toString)), + numProcessorThreads, + requestChannel, + quotas, + connectionsMaxIdleMs) Utils.newThread("kafka-network-thread-%d-%d".format(port, i), processors(i), false).start() } @@ -81,9 +84,9 @@ class SocketServer(val brokerId: Int, // register the processor threads for notification of responses requestChannel.addResponseListener((id:Int) => processors(id).wakeup()) - + // start accepting connections - this.acceptor = new Acceptor(host, port, processors, sendBufferSize, recvBufferSize, quotas) + this.acceptor = new Acceptor(host, port, processors, sendBufferSize, recvBufferSize, quotas, ipFilters) 
Utils.newThread("kafka-socket-acceptor", acceptor, false).start() acceptor.awaitStartup info("Started") @@ -142,12 +145,12 @@ private[kafka] abstract class AbstractServerThread(connectionQuotas: ConnectionQ * Is the server still running? */ protected def isRunning = alive.get - + /** * Wakeup the thread for selection. */ def wakeup() = selector.wakeup() - + /** * Close the given key and associated socket */ @@ -158,7 +161,7 @@ private[kafka] abstract class AbstractServerThread(connectionQuotas: ConnectionQ swallowError(key.cancel()) } } - + def close(channel: SocketChannel) { if(channel != null) { debug("Closing connection from " + channel.socket.getRemoteSocketAddress()) @@ -167,13 +170,13 @@ private[kafka] abstract class AbstractServerThread(connectionQuotas: ConnectionQ swallowError(channel.close()) } } - + /** * Close all open connections */ def closeAll() { // removes cancelled keys from selector.keys set - this.selector.selectNow() + this.selector.selectNow() val iter = this.selector.keys().iterator() while (iter.hasNext) { val key = iter.next() @@ -196,13 +199,14 @@ private[kafka] abstract class AbstractServerThread(connectionQuotas: ConnectionQ /** * Thread that accepts and configures new connections. There is only need for one of these */ -private[kafka] class Acceptor(val host: String, - val port: Int, +private[kafka] class Acceptor(val host: String, + val port: Int, private val processors: Array[Processor], - val sendBufferSize: Int, + val sendBufferSize: Int, val recvBufferSize: Int, - connectionQuotas: ConnectionQuotas) extends AbstractServerThread(connectionQuotas) { - val serverChannel = openServerSocket(host, port) + connectionQuotas: ConnectionQuotas, + ipFilters: IPFilter ) extends AbstractServerThread(connectionQuotas) { // thinking I should refactor to one list maybe +val serverChannel = openServerSocket(host, port) /** * Accept loop that checks for new connection attempts @@ -222,9 +226,9 @@ private[kafka] class Acceptor(val host: String, key = iter.next iter.remove() if(key.isAcceptable) - accept(key, processors(currentProcessor)) + accept(key, processors(currentProcessor)) else - throw new IllegalStateException("Unrecognized key state for acceptor thread.") + throw new IllegalStateException("Unrecognized key state for acceptor thread.") // round robin to the next processor thread currentProcessor = (currentProcessor + 1) % processors.length @@ -239,12 +243,12 @@ private[kafka] class Acceptor(val host: String, swallowError(selector.close()) shutdownComplete() } - + /* * Create a server socket to listen for connections on. 
  def openServerSocket(host: String, port: Int): ServerSocketChannel = {
-    val socketAddress = 
+    val socketAddress =
       if(host == null || host.trim.isEmpty)
         new InetSocketAddress(port)
       else
@@ -256,7 +260,7 @@ private[kafka] class Acceptor(val host: String,
       serverChannel.socket.bind(socketAddress)
       info("Awaiting socket connections on %s:%d.".format(socketAddress.getHostName, port))
     } catch {
-      case e: SocketException => 
+      case e: SocketException =>
         throw new KafkaException("Socket server failed to bind to %s:%d: %s.".format(socketAddress.getHostName, port, e.getMessage), e)
     }
     serverChannel
@@ -268,22 +272,27 @@ private[kafka] class Acceptor(val host: String,
   def accept(key: SelectionKey, processor: Processor) {
     val serverSocketChannel = key.channel().asInstanceOf[ServerSocketChannel]
     val socketChannel = serverSocketChannel.accept()
+    val address = socketChannel.socket().getInetAddress
     try {
-      connectionQuotas.inc(socketChannel.socket().getInetAddress)
+      connectionQuotas.inc(address)
+      ipFilters.check(address)
       socketChannel.configureBlocking(false)
       socketChannel.socket().setTcpNoDelay(true)
       socketChannel.socket().setSendBufferSize(sendBufferSize)
 
       debug("Accepted connection from %s on %s. sendBufferSize [actual|requested]: [%d|%d] recvBufferSize [actual|requested]: [%d|%d]"
-            .format(socketChannel.socket.getInetAddress, socketChannel.socket.getLocalSocketAddress,
-                  socketChannel.socket.getSendBufferSize, sendBufferSize,
-                  socketChannel.socket.getReceiveBufferSize, recvBufferSize))
+            .format(socketChannel.socket.getInetAddress, socketChannel.socket.getLocalSocketAddress,
+                    socketChannel.socket.getSendBufferSize, sendBufferSize,
+                    socketChannel.socket.getReceiveBufferSize, recvBufferSize))
 
       processor.accept(socketChannel)
     } catch {
       case e: TooManyConnectionsException =>
        info("Rejected connection from %s, address already has the configured maximum of %d connections.".format(e.ip, e.count))
        close(socketChannel)
+      case e: IPFilterException =>
+        info(e.getMessage)
+        close(socketChannel)
     }
   }
@@ -508,7 +517,7 @@ private[kafka] class Processor(val id: Int,
 class ConnectionQuotas(val defaultMax: Int, overrideQuotas: Map[String, Int]) {
   private val overrides = overrideQuotas.map(entry => (InetAddress.getByName(entry._1), entry._2))
   private val counts = mutable.Map[InetAddress, Int]()
-  
+
   def inc(addr: InetAddress) {
     counts synchronized {
       val count = counts.getOrElse(addr, 0)
@@ -518,7 +527,7 @@ class ConnectionQuotas(val defaultMax: Int, overrideQuotas: Map[String, Int]) {
       throw new TooManyConnectionsException(addr, max)
     }
   }
-  
+
   def dec(addr: InetAddress) {
     counts synchronized {
       val count = counts.get(addr).get
@@ -528,7 +537,7 @@ class ConnectionQuotas(val defaultMax: Int, overrideQuotas: Map[String, Int]) {
       counts.put(addr, count - 1)
     }
   }
-  
-}
+}
 
 class TooManyConnectionsException(val ip: InetAddress, val count: Int) extends KafkaException("Too many connections from %s (maximum = %d)".format(ip, count))
+
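A minimal wiring sketch for reviewers (the parameter order matches the new
SocketServer signature above; the concrete values are hypothetical):

    import kafka.network.{IPFilter, SocketServer}

    val server = new SocketServer(brokerId = 0,
                                  host = null,
                                  port = 9092,
                                  numProcessorThreads = 3,
                                  maxQueuedRequests = 500,
                                  sendBufferSize = 100 * 1024,
                                  recvBufferSize = 100 * 1024,
                                  maxRequestSize = 100 * 1024 * 1024,
                                  maxConnectionsPerIp = Int.MaxValue,
                                  connectionsMaxIdleMs = 10 * 60 * 1000L,
                                  ipFilterRuleType = IPFilter.DenyRule._1, // blacklist mode
                                  ipFilterList = List("10.0.0.0/8"))
    // maxConnectionsPerIpOverrides now defaults to an empty Map, so existing
    // callers that omit it keep compiling.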
diff --git a/core/src/main/scala/kafka/server/KafkaConfig.scala b/core/src/main/scala/kafka/server/KafkaConfig.scala
index 6e26c54..448ac35 100644
--- a/core/src/main/scala/kafka/server/KafkaConfig.scala
+++ b/core/src/main/scala/kafka/server/KafkaConfig.scala
@@ -21,6 +21,7 @@ import java.util.Properties
 import kafka.message.{MessageSet, Message}
 import kafka.consumer.ConsumerConfig
 import kafka.utils.{VerifiableProperties, ZKConfig, Utils}
+import kafka.network.IPFilter
 
 /**
  * Configuration settings for the kafka server
@@ -127,6 +128,14 @@ class KafkaConfig private (val props: VerifiableProperties) extends ZKConfig(pro
   /* idle connections timeout: the server socket processor threads close the connections that idle more than this */
   val connectionsMaxIdleMs = props.getLong("connections.max.idle.ms", 10*60*1000L)
 
+  /* the type of IP filtering list to be evaluated: "allow" (whitelist), "deny" (blacklist), or "none" (the default) */
+  val ipFilterRuleType = props.getString("security.ip.filter.rule.type", IPFilter.NoRule).toLowerCase
+
+  /* IP whitelist / blacklist, specified in CIDR notation, e.g. 192.168.1.1/32, 192.168.2.0/24;
+   * /32 denotes a single IPv4 address, /128 a single IPv6 address */
+  val ipFilterList = props.getList("security.ip.filter.list")
+
+
   /*********** Log Configuration ***********/
 
   /* the default number of log partitions per topic */
diff --git a/core/src/main/scala/kafka/server/KafkaServer.scala b/core/src/main/scala/kafka/server/KafkaServer.scala
index 1691ad7..e2f99e2 100644
--- a/core/src/main/scala/kafka/server/KafkaServer.scala
+++ b/core/src/main/scala/kafka/server/KafkaServer.scala
@@ -85,7 +85,7 @@ class KafkaServer(val config: KafkaConfig, time: Time = SystemTime) extends Logg
         logManager = createLogManager(zkClient, brokerState)
         logManager.startup()
 
-        socketServer = new SocketServer(config.brokerId, 
+        socketServer = new SocketServer(config.brokerId,
                                         config.hostName,
                                         config.port,
                                         config.numNetworkThreads,
@@ -95,7 +95,8 @@ class KafkaServer(val config: KafkaConfig, time: Time = SystemTime) extends Logg
                                         config.socketRequestMaxBytes,
                                         config.maxConnectionsPerIp,
                                         config.connectionsMaxIdleMs,
-                                        config.maxConnectionsPerIpOverrides)
+                                        config.ipFilterRuleType,
+                                        config.ipFilterList)
         socketServer.startup()
 
         replicaManager = new ReplicaManager(config, time, zkClient, kafkaScheduler, logManager, isShuttingDown)
diff --git a/core/src/main/scala/kafka/utils/VerifiableProperties.scala b/core/src/main/scala/kafka/utils/VerifiableProperties.scala
index 2ffc7f4..c8e09dd 100644
--- a/core/src/main/scala/kafka/utils/VerifiableProperties.scala
+++ b/core/src/main/scala/kafka/utils/VerifiableProperties.scala
@@ -196,6 +196,13 @@ class VerifiableProperties(val props: Properties) extends Logging {
   }
 
   /**
+   * Get a List[String] from a property specified as a comma-separated list, i.e. string,string,string
+   */
+  def getList(name: String): List[String] = {
+    Utils.parseCsvList(getString(name, "")).toList
+  }
+
+  /**
    * Parse compression codec from a property list in either. Codecs may be specified as integers, or as strings.
    * See [[kafka.message.CompressionCodec]] for more details.
    * @param name The property name
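With the two new settings in place, an operator would configure filtering in
server.properties like this (the keys come from KafkaConfig above; the values
are illustrative):

    # Blacklist: reject connections from these CIDR blocks
    security.ip.filter.rule.type=deny
    security.ip.filter.list=10.1.2.0/24,fe80::/64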
diff --git a/core/src/test/scala/unit/kafka/network/IPFilterTest.scala b/core/src/test/scala/unit/kafka/network/IPFilterTest.scala
new file mode 100644
index 0000000..5d4268b
--- /dev/null
+++ b/core/src/test/scala/unit/kafka/network/IPFilterTest.scala
@@ -0,0 +1,234 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package unit.kafka.network
+
+import java.net._
+import kafka.network.{IPFilterConfigException, CIDRRange, IPFilterException, IPFilter}
+import org.junit.Assert._
+import org.junit.Test
+import kafka.utils.Utils
+
+class IPFilterTest {
+  @Test
+  def testEmptyList(): Unit = {
+    val range = Utils.parseCsvList("").toList
+    val ipFilters = new IPFilter(range, IPFilter.NoRule)
+
+    val ip = InetAddress.getByName("192.168.2.1")
+    assertTrue(ipFilters.check(ip))
+  }
+
+  @Test
+  def testRuleType(): Unit = {
+    val range: List[String] = List("192.168.1.1/32")
+    val range2: List[String] = List.empty
+    val ruleType = "Allowe" // deliberately misspelled
+    val ruleTypeNone = IPFilter.NoRule // the default
+    val ruleTypeAllow = IPFilter.AllowRule._1
+    val ruleTypeDeny = IPFilter.DenyRule._1
+    val ipFilter1 = new IPFilter(range2, ruleTypeNone)
+    val ipFilter2 = new IPFilter(range, ruleTypeAllow)
+    val ipFilter3 = new IPFilter(range, ruleTypeDeny)
+
+    try {
+      val ipFilter = new IPFilter(range, ruleType)
+      fail()
+    } catch {
+      case e: IPFilterConfigException =>
+    }
+    assertEquals(ipFilter1.ruleType, IPFilter.NoRule)
+    assertEquals(ipFilter2.ruleType, IPFilter.AllowRule._1)
+    assertEquals(ipFilter3.ruleType, IPFilter.DenyRule._1)
+  }
+
+  @Test
+  def testRangeAndRuleType(): Unit = {
+    val range = List.empty[String]
+    val ruleType = IPFilter.AllowRule._1
+    val range2 = List("192.168.2.1/32")
+    val ruleType2 = IPFilter.NoRule
+
+    try {
+      val ipFilter = new IPFilter(range, ruleType)
+      fail()
+    } catch {
+      case e: IPFilterConfigException =>
+    }
+
+    try {
+      val ipFilter = new IPFilter(range2, ruleType2)
+      fail()
+    } catch {
+      case e: IPFilterConfigException =>
+    }
+  }
+
+  @Test
+  def testBadPrefix(): Unit = {
+    val range = List(List("192.168.2.1/-1"), List("192.168.2.1/64"))
+    for (l <- range) {
+      try {
+        new IPFilter(l, IPFilter.NoRule)
+        fail()
+      } catch {
+        case e: IPFilterConfigException =>
+      }
+    }
+  }
+
+  @Test
+  // Checks that an address is correctly classified as inside or outside a range
+  def testIpV4Range(): Unit = {
+    val ipRange: String = "192.168.2.0/25" // covers .0-.127
+    val cidr = new CIDRRange(ipRange)
+    val ip1 = InetAddress.getByName("192.168.2.1")
+    val ip2 = InetAddress.getByName("192.168.2.128")
+    assertTrue(cidr.contains(ip1))
+    assertFalse(cidr.contains(ip2))
+  }
+
+  @Test
+  def testIpV6Range1(): Unit = {
+    val ipRange: String = "fe80:0:0:0:202:b3ff:fe1e:8320/124"
+    val cidr = new CIDRRange(ipRange)
+    val ip1 = InetAddress.getByName("fe80:0:0:0:202:b3ff:fe1e:8320")
+    val ip2 = InetAddress.getByName("fe80:0:0:0:202:b3ff:fe1e:833f")
+    assertTrue(cidr.contains(ip1))
+    assertFalse(cidr.contains(ip2))
+  }
+
+  @Test
+  def testIPV6Range2(): Unit = {
+    val ipRange: String = "fe80:0:0:0:202:b3ff:fe1e:8320/64"
+    val cidr = new CIDRRange(ipRange)
+    val ip1 = InetAddress.getByName("fe80:0000:0000:0000:0202:b3ff:fe1e:8320")
+    val ip2 = InetAddress.getByName("fe80:0000:0000:0000:ffff:ffff:ffff:ffff")
+    val ip3 = InetAddress.getByName("fe80:0:0:3:ffff:ffff:ffff:ffff")
+    assertTrue(cidr.contains(ip1))
+    assertTrue(cidr.contains(ip2))
+    assertFalse(cidr.contains(ip3))
+  }
+
+  // Somewhat artificial, but exercises the mask logic for prefixes < 32 with IPv6
+  @Test
+  def testIPV6Range3(): Unit = {
+    val ipRange: String = "fe80:0:0:0:202:b3ff:fe1e:8320/12"
+    val cidr = new CIDRRange(ipRange)
+    val ip1 = InetAddress.getByName("fe80:0000:0000:0000:0202:b3ff:fe1e:8320")
+    val ip2 = InetAddress.getByName("fe8f:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
+    val ip3 = InetAddress.getByName("fe9f:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
+    assertTrue(cidr.contains(ip1))
+    assertTrue(cidr.contains(ip2))
+    assertFalse(cidr.contains(ip3))
+  }
InetAddress.getByName("fe9f:ffff:ffff:ffff:ffff:ffff:ffff:ffff") + assertTrue(cidr.contains(ip1)) + assertTrue(cidr.contains(ip2)) + assertFalse(cidr.contains(ip3)) + } + + @Test + def testSingleBlackList(): Unit = { + val range1 = List("192.168.2.10/32") + val ipFilters = new IPFilter(range1, IPFilter.DenyRule._1) + val ip = InetAddress.getByName("192.168.2.10") + try { + ipFilters.check(ip) + fail() + } catch { + case e: IPFilterException => // this is good + } + val ip2 = InetAddress.getByName("192.168.2.1") + assertTrue(ipFilters.check(ip2)) + } + + @Test + def testMultipleBlackListEntries(): Unit = { + val range = List("192.168.2.0/28", "192.168.2.25/28") + val ipFilters = new IPFilter(range, IPFilter.DenyRule._1) + val ip1 = InetAddress.getByName("192.168.2.3") + val ip2 = InetAddress.getByName("192.168.2.26") + val ip3 = InetAddress.getByName("192.162.1.1") + try { + ipFilters.check(ip1) + fail() + } catch { + case e: IPFilterException => // this is good + } + try { + ipFilters.check(ip2) + fail() + } catch { + case e: IPFilterException => // this is good + } + assertTrue(ipFilters.check(ip3)) + } + + @Test + def testSingleWhiteList(): Unit = { + val range1 = List("192.168.2.10/32") + val ipFilters = new IPFilter(range1, IPFilter.AllowRule._1) + val ip = InetAddress.getByName("192.168.2.10") + val ip2 = InetAddress.getByName("192.168.2.1") + assertTrue(ipFilters.check(ip)) + try { + ipFilters.check(ip2) + fail() + } catch { + case e: IPFilterException => // this is good + } + } + + @Test + def testMultipleWhiteListEntries(): Unit = { + val range = List("192.168.2.0/24", "10.10.10.0/16") + val ipFilters = new IPFilter(range, IPFilter.AllowRule._1) + val ip1 = InetAddress.getByName("192.168.2.128") + val ip2 = InetAddress.getByName("192.168.1.128") + val ip3 = InetAddress.getByName("10.10.1.1") + val ip4 = InetAddress.getByName("10.9.1.1") + assertTrue(ipFilters.check(ip1)) + try { + ipFilters.check(ip2) + fail() + } catch { + case e: IPFilterException => // this is good + } + assertTrue(ipFilters.check(ip3)) + try { + ipFilters.check(ip4) + fail() + } catch { + case e: IPFilterException => // this is good + } + } + + @Test + def testRangeFormat(): Unit = { + val ruleType = IPFilter.AllowRule._1 + val range1 = List("192.168.2") + val range2 = List("192.168.1.2/32", "10.A.B.C/AAAA") + val range3 = List("blahblahblah:") + val range4 = List("192aaaa:") + val rangeList: List[List[String]] = List(range1, range2, range2, range4) + + for (l <- rangeList) + try { + val ipFilters = new IPFilter(l, ruleType) + fail() + } catch { + case e: IPFilterConfigException => + } + } +} + + diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala index 78b431f..f04e0cc 100644 --- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -34,18 +34,23 @@ import scala.collection.Map class SocketServerTest extends JUnitSuite { - val server: SocketServer = new SocketServer(0, - host = null, - port = kafka.utils.TestUtils.choosePort, - numProcessorThreads = 1, - maxQueuedRequests = 50, - sendBufferSize = 300000, - recvBufferSize = 300000, - maxRequestSize = 50, - maxConnectionsPerIp = 5, - connectionsMaxIdleMs = 60*1000, - maxConnectionsPerIpOverrides = Map.empty[String,Int]) - server.startup() + def getSocketServer(brokerId: Int = 0, + host: String = null, + port: Int = kafka.utils.TestUtils.choosePort, + numProcessorThreads: Int = 1, + 
diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
index 78b431f..f04e0cc 100644
--- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
+++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
@@ -34,18 +34,23 @@ import scala.collection.Map
 
 class SocketServerTest extends JUnitSuite {
 
-  val server: SocketServer = new SocketServer(0,
-                                              host = null,
-                                              port = kafka.utils.TestUtils.choosePort,
-                                              numProcessorThreads = 1,
-                                              maxQueuedRequests = 50,
-                                              sendBufferSize = 300000,
-                                              recvBufferSize = 300000,
-                                              maxRequestSize = 50,
-                                              maxConnectionsPerIp = 5,
-                                              connectionsMaxIdleMs = 60*1000,
-                                              maxConnectionsPerIpOverrides = Map.empty[String,Int])
-  server.startup()
+  def getSocketServer(brokerId: Int = 0,
+                      host: String = null,
+                      port: Int = kafka.utils.TestUtils.choosePort,
+                      numProcessorThreads: Int = 1,
+                      maxQueuedRequests: Int = 50,
+                      sendBufferSize: Int = 300000,
+                      recvBufferSize: Int = 300000,
+                      maxRequestSize: Int = 50,
+                      maxConnectionsPerIp: Int = 5,
+                      connectionsMaxIdleMs: Int = 60 * 1000,
+                      ipFilterRuleType: String = IPFilter.NoRule,
+                      ipFilterList: List[String] = List.empty[String],
+                      maxConnectionsPerIpOverrides: Map[String, Int] = Map.empty[String, Int]): SocketServer = {
+
+    new SocketServer(brokerId, host, port, numProcessorThreads, maxQueuedRequests, sendBufferSize, recvBufferSize, maxRequestSize,
+                     maxConnectionsPerIp, connectionsMaxIdleMs, ipFilterRuleType, ipFilterList, maxConnectionsPerIpOverrides)
+  }
 
   def sendRequest(socket: Socket, id: Short, request: Array[Byte]) {
     val outgoing = new DataOutputStream(socket.getOutputStream)
@@ -75,6 +80,9 @@ class SocketServerTest extends JUnitSuite {
 
   def connect(s:SocketServer = server) = new Socket("localhost", s.port)
 
+  val server: SocketServer = getSocketServer()
+  server.startup()
+
   @After
   def cleanup() {
     server.shutdown()
@@ -154,20 +162,10 @@ class SocketServerTest extends JUnitSuite {
   }
 
   @Test
-  def testMaxConnectionsPerIPOverrides(): Unit = {
+  def testMaxConnectionsPerIPOverrides() = {
     val overrideNum = 6
     val overrides: Map[String, Int] = Map("localhost" -> overrideNum)
-    val overrideServer: SocketServer = new SocketServer(0,
-                                                        host = null,
-                                                        port = kafka.utils.TestUtils.choosePort,
-                                                        numProcessorThreads = 1,
-                                                        maxQueuedRequests = 50,
-                                                        sendBufferSize = 300000,
-                                                        recvBufferSize = 300000,
-                                                        maxRequestSize = 50,
-                                                        maxConnectionsPerIp = 5,
-                                                        connectionsMaxIdleMs = 60*1000,
-                                                        maxConnectionsPerIpOverrides = overrides)
+    val overrideServer: SocketServer = getSocketServer(maxConnectionsPerIpOverrides = overrides)
     overrideServer.startup()
     // make the maximum allowable number of connections and then leak them
     val conns = ((0 until overrideNum).map(i => connect(overrideServer)))
@@ -177,4 +175,62 @@ class SocketServerTest extends JUnitSuite {
     assertEquals(-1, conn.getInputStream.read())
     overrideServer.shutdown()
   }
+
+  @Test
+  def testIPFilterBlackList() = {
+    val filterList: List[String] = List("localhost/32")
+    val ipFilterServer = getSocketServer(ipFilterRuleType = IPFilter.DenyRule._1, ipFilterList = filterList)
+    ipFilterServer.startup()
+    val conn = connect(ipFilterServer)
+    conn.setSoTimeout(3000)
+    assertEquals(-1, conn.getInputStream.read())
+    ipFilterServer.shutdown()
+  }
+
+  @Test
+  def testIPFilterWhiteList() = {
+    val filterList: List[String] = List("localhost/32")
+    val ipFilterServer = getSocketServer(ipFilterRuleType = IPFilter.AllowRule._1, ipFilterList = filterList)
+    ipFilterServer.startup()
+    val conn = connect(ipFilterServer)
+    conn.setSoTimeout(3000)
+    try {
+      conn.getInputStream.read()
+      fail()
+    } catch {
+      case e: SocketTimeoutException => // means the socket is still open
+    }
+    ipFilterServer.shutdown()
+  }
+
+  @Test
+  def testIPFilterNone() = {
+    val filterList: List[String] = List.empty
+    val ipFilterServer = getSocketServer(ipFilterRuleType = IPFilter.NoRule, ipFilterList = filterList)
+    ipFilterServer.startup()
+    val conn = connect(ipFilterServer)
+    conn.setSoTimeout(3000)
+    try {
+      conn.getInputStream.read()
+      fail()
+    } catch {
+      case e: SocketTimeoutException => // means the socket is still open
+    }
+    ipFilterServer.shutdown()
+  }
+
+  @Test
+  def testIPFilterMisConfig() = {
+    val filterList: List[String] = List("localhostblahblahblah/32")
+    val ipFilterServer = getSocketServer(ipFilterRuleType = IPFilter.AllowRule._1, ipFilterList = filterList)
+    try {
+      ipFilterServer.startup()
+      fail()
+    } catch {
+      case e: IPFilterConfigException =>
+    }
+    // No shutdown needed: startup() failed, so no acceptor thread was started
+  }
+
 }
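From a client's perspective the deny and allow paths look different, which is
what the read timeouts above are checking. Illustrative sketch mirroring
testIPFilterBlackList (ipFilterServer as in that test):

    // A deny-listed client completes the TCP handshake, then the broker
    // closes the socket, so the first read returns -1.
    val sock = new java.net.Socket("localhost", ipFilterServer.port)
    sock.getInputStream.read() // -1 => closed by the broker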
diff --git a/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala b/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala
index 2377abe..c095e38 100644
--- a/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala
+++ b/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala
@@ -21,7 +21,7 @@ import org.junit.Test
 import junit.framework.Assert._
 import org.scalatest.junit.JUnit3Suite
 import kafka.utils.TestUtils
-
+import kafka.network.IPFilter
 
 class KafkaConfigTest extends JUnit3Suite {
 
   @Test
@@ -180,6 +180,15 @@ class KafkaConfigTest extends JUnit3Suite {
 
     assertEquals(24 * 7 * 60L * 60L * 1000L, cfg.logRollTimeMillis )
   }
-  
+
+  @Test
+  def testIPFilterRules(): Unit = {
+    val props = TestUtils.createBrokerConfig(0, 8181)
+    val cfg = new KafkaConfig(props)
+    assertEquals(IPFilter.NoRule, cfg.ipFilterRuleType)
+    assertEquals(List.empty[String], cfg.ipFilterList)
+  }
+
 }
-- 
1.8.4