Index: core/src/test/java/org/apache/hama/ipc/TestAsyncIPC.java =================================================================== --- core/src/test/java/org/apache/hama/ipc/TestAsyncIPC.java (리비전 0) +++ core/src/test/java/org/apache/hama/ipc/TestAsyncIPC.java (리비전 0) @@ -0,0 +1,227 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.ipc; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.Random; + +import junit.framework.TestCase; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.util.StringUtils; + +public class TestAsyncIPC extends TestCase { + private static int sendindCount = 0; + private static int recievedCount = 0; + public static final Log LOG = LogFactory.getLog(TestAsyncIPC.class); + + final private static Configuration conf = new Configuration(); + final static private int PING_INTERVAL = 1000; + + static { + Client.setPingInterval(conf, PING_INTERVAL); + } + + public TestAsyncIPC(String name) { + super(name); + } + + private static final Random RANDOM = new Random(); + + private static final String ADDRESS = "0.0.0.0"; + private static int port = 7000; + + private static class TestServer extends AsyncServer { + private boolean sleep; + + + public TestServer(int handlerCount, boolean sleep) throws IOException { + super(ADDRESS, port++, LongWritable.class, handlerCount, conf); + this.sleep = sleep; + } + + @Override + public Writable call(Class protocol, Writable param, long receiveTime) + throws IOException { + if (sleep) { + try { + Thread.sleep(RANDOM.nextInt(2 * PING_INTERVAL)); // sleep a bit + } catch (InterruptedException e) { + } + } + return param; // echo param as result + } + } + + private static class SerialCaller extends Thread { + private AsyncClient client; + private InetSocketAddress serverAddress; + private int count; + private boolean failed; + + + public SerialCaller(AsyncClient client, InetSocketAddress server, + int count) { + this.client = client; + this.serverAddress = server; + this.count = count; + } + + @Override + public void run() { + for (int i = 0; i < count; i++) { + try { + LongWritable param = new LongWritable(RANDOM.nextLong()); + + LongWritable value = (LongWritable) client.call(param, serverAddress, + null, null, 0, conf); + if (!param.equals(value)) { + LOG.fatal("Call failed!"); + failed = true; + break; + } else { + successCount(); + } + + } catch (Exception e) { + LOG.fatal("Caught: " + StringUtils.stringifyException(e)); + failed = true; + } + } + } + } + + public static synchronized void successCount() { + ++recievedCount; + } + + private static class ParallelCaller extends Thread { + private AsyncClient client; + private int count; + private InetSocketAddress[] addresses; + private boolean failed; + + public ParallelCaller(AsyncClient client, + InetSocketAddress[] addresses, int count) { + this.client = client; + this.addresses = addresses; + this.count = count; + } + + @Override + public void run() { + for (int i = 0; i < count; i++) { + try { + Writable[] params = new Writable[addresses.length]; + for (int j = 0; j < addresses.length; j++) + params[j] = new LongWritable(RANDOM.nextLong()); + Writable[] values = client.call(params, addresses, null, null, conf); + + for (int j = 0; j < addresses.length; j++) { + if (!params[j].equals(values[j])) { + LOG.fatal("Call failed!"); + failed = true; + break; + } + } + } catch (Exception e) { + LOG.fatal("Caught: " + StringUtils.stringifyException(e)); + failed = true; + } + } + } + } + + public void testSerial() throws Exception { + testSerial(3, false, 2, 5, 100); + } + + public void testSerial(int handlerCount, boolean handlerSleep, + int clientCount, int callerCount, int callCount) throws Exception { + AsyncServer server = new TestServer(handlerCount, handlerSleep); + InetSocketAddress addr = server.getAddress(); + server.start(); + + AsyncClient[] clients = new AsyncClient[clientCount]; + for (int i = 0; i < clientCount; i++) { + clients[i] = new AsyncClient(LongWritable.class, conf); + } + + SerialCaller[] callers = new SerialCaller[callerCount]; + for (int i = 0; i < callerCount; i++) { + callers[i] = new SerialCaller(clients[i % clientCount], addr, callCount); + callers[i].start(); + } + for (int i = 0; i < callerCount; i++) { + callers[i].join(); + assertFalse(callers[i].failed); + } + for (int i = 0; i < clientCount; i++) { + clients[i].stop(); + } + System.out.println("Serial Sending Cound/Recieved Count : " + sendindCount + + "/" + recievedCount); + server.stop(); + } + + public void testParallel() throws Exception { + testParallel(10, false, 2, 4, 2, 4, 100); + } + + public void testParallel(int handlerCount, boolean handlerSleep, + int serverCount, int addressCount, int clientCount, int callerCount, + int callCount) throws Exception { + AsyncServer[] servers = new AsyncServer[serverCount]; + for (int i = 0; i < serverCount; i++) { + servers[i] = new TestServer(handlerCount, handlerSleep); + servers[i].start(); + } + + InetSocketAddress[] addresses = new InetSocketAddress[addressCount]; + for (int i = 0; i < addressCount; i++) { + addresses[i] = servers[i % serverCount].address; + } + + AsyncClient[] clients = new AsyncClient[clientCount]; + for (int i = 0; i < clientCount; i++) { + clients[i] = new AsyncClient(LongWritable.class, conf); + } + + ParallelCaller[] callers = new ParallelCaller[callerCount]; + for (int i = 0; i < callerCount; i++) { + callers[i] = new ParallelCaller(clients[i % clientCount], addresses, + callCount); + callers[i].start(); + } + for (int i = 0; i < callerCount; i++) { + callers[i].join(); + assertFalse(callers[i].failed); + } + for (int i = 0; i < clientCount; i++) { + clients[i].stop(); + } + for (int i = 0; i < serverCount; i++) { + servers[i].stop(); + } + } +} Index: core/src/test/java/org/apache/hama/ipc/TestAsyncRPC.java =================================================================== --- core/src/test/java/org/apache/hama/ipc/TestAsyncRPC.java (리비전 0) +++ core/src/test/java/org/apache/hama/ipc/TestAsyncRPC.java (리비전 0) @@ -0,0 +1,181 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.ipc; + +import java.io.IOException; +import java.lang.reflect.Method; +import java.net.InetSocketAddress; +import java.util.Arrays; + +import junit.framework.TestCase; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.Writable; + +public class TestAsyncRPC extends TestCase { + private static final int PORT = 1234; + private static final String ADDRESS = "0.0.0.0"; + + public static final Log LOG = LogFactory + .getLog("org.apache.hama.ipc.TestRPCWithNetty"); + + private static Configuration conf = new Configuration(); + + public TestAsyncRPC(String name) { + super(name); + } + + + public interface TestProtocol extends VersionedProtocol { + public static final long versionID = 1L; + + + void ping() throws IOException; + + String echo(String value) throws IOException; + + String[] echo(String[] value) throws IOException; + + Writable echo(Writable value) throws IOException; + + int add(int v1, int v2) throws IOException; + + int add(int[] values) throws IOException; + + int error() throws IOException; + + void testServerGet() throws IOException; + } + + public class TestImpl implements TestProtocol { + + @Override + public long getProtocolVersion(String protocol, long clientVersion) { + return TestProtocol.versionID; + } + + @Override + public void ping() { + } + + @Override + public String echo(String value) throws IOException { + return value; + } + + @Override + public String[] echo(String[] values) throws IOException { + return values; + } + + @Override + public Writable echo(Writable writable) { + return writable; + } + + @Override + public int add(int v1, int v2) { + return v1 + v2; + } + + @Override + public int add(int[] values) { + int sum = 0; + for (int i = 0; i < values.length; i++) { + sum += values[i]; + } + return sum; + } + + @Override + public int error() throws IOException { + throw new IOException("bobo"); + } + + @Override + public void testServerGet() throws IOException { + AsyncServer server = AsyncServer.get(); + if (!(server instanceof AsyncRPC.NioServer)) { + throw new IOException("ServerWithNetty.get() failed"); + } + } + } + + + public void testCalls() throws Exception { + AsyncServer server = AsyncRPC.getServer(new TestImpl(), ADDRESS, + PORT, conf); + server.start(); + + InetSocketAddress addr = new InetSocketAddress(PORT); + TestProtocol proxy = (TestProtocol) AsyncRPC.getProxy( + TestProtocol.class, TestProtocol.versionID, addr, conf); + + proxy.ping(); + + String stringResult = proxy.echo("foo"); + assertEquals(stringResult, "foo"); + + stringResult = proxy.echo((String) null); + assertEquals(stringResult, null); + + String[] stringResults = proxy.echo(new String[] { "foo", "bar" }); + assertTrue(Arrays.equals(stringResults, new String[] { "foo", "bar" })); + + stringResults = proxy.echo((String[]) null); + assertTrue(Arrays.equals(stringResults, null)); + + int intResult = proxy.add(1, 2); + assertEquals(intResult, 3); + + intResult = proxy.add(new int[] { 1, 2 }); + assertEquals(intResult, 3); + + boolean caught = false; + try { + proxy.error(); + } catch (Exception e) { + LOG.debug("Caught " + e); + caught = true; + } + assertTrue(caught); + + proxy.testServerGet(); + + // try some multi-calls + Method echo = TestProtocol.class.getMethod("echo", + new Class[] { String.class }); + String[] strings = (String[]) AsyncRPC.call(echo, new String[][] { + { "a" }, { "b" } }, new InetSocketAddress[] { addr, addr }, null, conf); + assertTrue(Arrays.equals(strings, new String[] { "a", "b" })); + + Method ping = TestProtocol.class.getMethod("ping", new Class[] {}); + Object[] voids = AsyncRPC.call(ping, new Object[][] { {}, {} }, + new InetSocketAddress[] { addr, addr }, null, conf); + assertEquals(voids, null); + + server.stop(); + } + + public static void main(String[] args) throws Exception { + new TestAsyncRPC("test").testCalls(); + } + +} Index: core/src/test/java/org/apache/hama/bsp/message/TestHamaAsyncMessageManager.java =================================================================== --- core/src/test/java/org/apache/hama/bsp/message/TestHamaAsyncMessageManager.java (리비전 0) +++ core/src/test/java/org/apache/hama/bsp/message/TestHamaAsyncMessageManager.java (리비전 0) @@ -0,0 +1,114 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.bsp.message; + +import java.net.InetSocketAddress; +import java.util.Iterator; +import java.util.Map.Entry; + +import junit.framework.TestCase; + +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.NullWritable; +import org.apache.hama.Constants; +import org.apache.hama.HamaConfiguration; +import org.apache.hama.bsp.BSPMessageBundle; +import org.apache.hama.bsp.BSPPeer; +import org.apache.hama.bsp.BSPPeerImpl; +import org.apache.hama.bsp.Counters; +import org.apache.hama.bsp.TaskAttemptID; +import org.apache.hama.bsp.message.queue.DiskQueue; +import org.apache.hama.bsp.message.queue.MemoryQueue; +import org.apache.hama.bsp.message.queue.MessageQueue; +import org.apache.hama.util.BSPNetUtils; + +public class TestHamaAsyncMessageManager extends TestCase { + + public static final String TMP_OUTPUT_PATH = "/tmp/messageQueue"; + // increment is here to solve race conditions in parallel execution to choose + // other ports. + public static volatile int increment = 1; + + public void testMemoryMessaging() throws Exception { + HamaConfiguration conf = new HamaConfiguration(); + conf.setClass(MessageManager.RECEIVE_QUEUE_TYPE_CLASS, MemoryQueue.class, + MessageQueue.class); + conf.set(DiskQueue.DISK_QUEUE_PATH_KEY, TMP_OUTPUT_PATH); + messagingInternal(conf); + } + + public void testDiskMessaging() throws Exception { + HamaConfiguration conf = new HamaConfiguration(); + conf.set(DiskQueue.DISK_QUEUE_PATH_KEY, TMP_OUTPUT_PATH); + conf.setClass(MessageManager.RECEIVE_QUEUE_TYPE_CLASS, DiskQueue.class, + MessageQueue.class); + messagingInternal(conf); + } + + private static void messagingInternal(HamaConfiguration conf) + throws Exception { + conf.set(MessageManagerFactory.MESSAGE_MANAGER_CLASS, + "org.apache.hama.bsp.message.HamaAsyncMessageManagerImpl"); + MessageManager messageManager = MessageManagerFactory + .getMessageManager(conf); + + assertTrue(messageManager instanceof HamaAsyncMessageManagerImpl); + + InetSocketAddress peer = new InetSocketAddress( + BSPNetUtils.getCanonicalHostname(), BSPNetUtils.getFreePort() + + (increment++)); + conf.set(Constants.PEER_HOST, Constants.DEFAULT_PEER_HOST); + conf.setInt(Constants.PEER_PORT, Constants.DEFAULT_PEER_PORT); + + BSPPeer dummyPeer = new BSPPeerImpl( + conf, FileSystem.get(conf), new Counters()); + TaskAttemptID id = new TaskAttemptID("1", 1, 1, 1); + messageManager.init(id, dummyPeer, conf, peer); + peer = messageManager.getListenerAddress(); + String peerName = peer.getHostName() + ":" + peer.getPort(); + System.out.println("Peer is " + peerName); + messageManager.send(peerName, new IntWritable(1337)); + + Iterator>> messageIterator = messageManager + .getOutgoingBundles(); + + Entry> entry = messageIterator + .next(); + + assertEquals(entry.getKey(), peer); + + assertTrue(entry.getValue().size() == 1); + + BSPMessageBundle bundle = new BSPMessageBundle(); + Iterator it = entry.getValue().iterator(); + while (it.hasNext()) { + bundle.addMessage(it.next()); + } + + messageManager.transfer(peer, bundle); + + messageManager.clearOutgoingMessages(); + + assertTrue(messageManager.getNumCurrentMessages() == 1); + IntWritable currentMessage = messageManager.getCurrentMessage(); + + assertEquals(currentMessage.get(), 1337); + messageManager.close(); + } +} Index: core/src/main/java/org/apache/hama/ipc/AsyncServer.java =================================================================== --- core/src/main/java/org/apache/hama/ipc/AsyncServer.java (리비전 0) +++ core/src/main/java/org/apache/hama/ipc/AsyncServer.java (리비전 0) @@ -0,0 +1,621 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.ipc; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.handler.logging.LogLevel; +import io.netty.handler.logging.LoggingHandler; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.netty.util.ReferenceCountUtil; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.IOUtils; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.WritableUtils; +import org.apache.hadoop.security.SaslRpcServer.AuthMethod; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.security.token.SecretManager; +import org.apache.hadoop.security.token.TokenIdentifier; +import org.apache.hadoop.util.ReflectionUtils; +import org.apache.hadoop.util.StringUtils; + +/** + * An abstract IPC service using netty. IPC calls take a single {@link Writable} + * as a parameter, and return a {@link Writable}* + * + * @see AsyncServer + */ +public abstract class AsyncServer { + + private AuthMethod authMethod; + private static final ByteBuffer HEADER = ByteBuffer.wrap("hrpc".getBytes()); + static int INITIAL_RESP_BUF_SIZE = 10240; + UserGroupInformation user = null; + // 1 : Introduce ping and server does not throw away RPCs + // 3 : Introduce the protocol into the RPC connection header + // 4 : Introduced SASL security layer + private static final byte CURRENT_VERSION = 4; + // follows version is read. + private Configuration conf; + private final boolean tcpNoDelay; // if T then disable Nagle's Algorithm + private boolean isSecurityEnabled; + InetSocketAddress address; + private static final Log LOG = LogFactory.getLog(AsyncServer.class); + private static int NIO_BUFFER_LIMIT = 8 * 1024; + private final int maxRespSize; + static final String IPC_SERVER_RPC_MAX_RESPONSE_SIZE_KEY = "ipc.server.max.response.size"; + static final int IPC_SERVER_RPC_MAX_RESPONSE_SIZE_DEFAULT = 1024 * 1024; + + private static final ThreadLocal SERVER = new ThreadLocal(); + private int port; // port we listen on + private Class paramClass; // class of call parameters + // Configure the server.(constructor is thread num) + private EventLoopGroup bossGroup = new NioEventLoopGroup(1); + private EventLoopGroup workerGroup = new NioEventLoopGroup(); + private static final Map> PROTOCOL_CACHE = new ConcurrentHashMap>(); + private ExceptionsHandler exceptionsHandler = new ExceptionsHandler(); + + static Class getProtocolClass(String protocolName, Configuration conf) + throws ClassNotFoundException { + Class protocol = PROTOCOL_CACHE.get(protocolName); + if (protocol == null) { + protocol = conf.getClassByName(protocolName); + PROTOCOL_CACHE.put(protocolName, protocol); + } + return protocol; + } + + /** + * Getting address + * + * @return InetSocketAddress + */ + public InetSocketAddress getAddress() { + return address; + } + + /** + * Returns the server instance called under or null. May be called under + * {@link #call(Writable, long)} implementations, and under {@link Writable} + * methods of paramters and return values. Permits applications to access the + * server context. + * + * @return NioServer + */ + public static AsyncServer get() { + return SERVER.get(); + } + + /** + * Constructs a server listening on the named port and address. Parameters + * passed must be of the named class. The + * handlerCount determines + * the number of handler threads that will be used to process calls. + * + * @param bindAddress + * @param port + * @param paramClass + * @param handlerCount + * @param conf + * @throws IOException + */ + protected AsyncServer(String bindAddress, int port, + Class paramClass, int handlerCount, Configuration conf) + throws IOException { + this(bindAddress, port, paramClass, handlerCount, conf, Integer + .toString(port), null); + } + + protected AsyncServer(String bindAddress, int port, + Class paramClass, int handlerCount, + Configuration conf, String serverName) throws IOException { + this(bindAddress, port, paramClass, handlerCount, conf, serverName, null); + } + + protected AsyncServer(String bindAddress, int port, + Class paramClass, int handlerCount, + Configuration conf, String serverName, + SecretManager secretManager) + throws IOException { + this.conf = conf; + this.port = port; + this.address = new InetSocketAddress(bindAddress, port); + this.paramClass = paramClass; + this.maxRespSize = conf.getInt(IPC_SERVER_RPC_MAX_RESPONSE_SIZE_KEY, + IPC_SERVER_RPC_MAX_RESPONSE_SIZE_DEFAULT); + + this.isSecurityEnabled = UserGroupInformation.isSecurityEnabled(); + this.tcpNoDelay = conf.getBoolean("ipc.server.tcpnodelay", false); + } + + /** start server listener */ + public void start() { + new NioServerListener().start(); + } + + private class NioServerListener extends Thread { + + /** + * Configure and start nio server + */ + @Override + public void run() { + + // Configure SSL. + SERVER.set(AsyncServer.this); + final SslContext sslCtx; + try { + if (isSecurityEnabled) { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + sslCtx = SslContext.newServerContext(ssc.certificate(), + ssc.privateKey()); + } else { + sslCtx = null; + } + // ServerBootstrap is a helper class that sets up a server + ServerBootstrap b = new ServerBootstrap(); + b.group(bossGroup, workerGroup).channel(NioServerSocketChannel.class) + .childHandler(new ChannelInitializer() { + @Override + public void initChannel(SocketChannel ch) throws Exception { + ChannelPipeline p = ch.pipeline(); + if (sslCtx != null) { + p.addLast(sslCtx.newHandler(ch.alloc())); + } + // Register message processing handler + p.addLast(new NioServerInboundHandler()); + } + }) + .childOption(ChannelOption.TCP_NODELAY, tcpNoDelay) + .childOption(ChannelOption.MAX_MESSAGES_PER_READ, NIO_BUFFER_LIMIT) + .handler(new LoggingHandler(LogLevel.INFO)); + + // Bind and start to accept incoming connections. + ChannelFuture f = b.bind(port).sync(); + // Wait until the server socket is closed. + f.channel().closeFuture().sync(); + } catch (Exception e) { + e.printStackTrace(); + } finally { + // Shut down Server gracefully + bossGroup.shutdownGracefully(); + workerGroup.shutdownGracefully(); + } + } + } + + /** Stops the server gracefully. */ + public void stop() { + if (bossGroup != null && !bossGroup.isTerminated()) { + bossGroup.shutdownGracefully(); + } + if (workerGroup != null && !workerGroup.isTerminated()) { + workerGroup.shutdownGracefully(); + } + LOG.info("server shutdown"); + } + + /** + * This class process received message from client and send response message. + */ + private class NioServerInboundHandler extends ChannelInboundHandlerAdapter { + ConnectionHeader header = new ConnectionHeader(); + Class protocol; + private String errorClass = null; + private String error = null; + private boolean rpcHeaderRead = false; // if initial rpc header is read + private boolean headerRead = false; // if the connection header that follows + // version is read. + + /** + * Be invoked only one when a connection is established and ready to + * generate traffic + * + * @param ctx + */ + @Override + public void channelActive(ChannelHandlerContext ctx) { + SERVER.set(AsyncServer.this); + } + + /** + * Process a recieved message from client. This method is called with the + * received message, whenever new data is received from a client. + * + * @param ctx + * @param cause + */ + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + ByteBuffer dataLengthBuffer = ByteBuffer.allocate(4); + ByteBuf byteBuf = (ByteBuf) msg; + + ByteBuffer data = null; + ByteBuffer rpcHeaderBuffer = null; + while (true) { + Call call = null; + errorClass = null; + error = null; + try { + if (dataLengthBuffer.remaining() > 0 && byteBuf.isReadable()) { + byteBuf.readBytes(dataLengthBuffer); + if (dataLengthBuffer.remaining() > 0 && byteBuf.isReadable()) { + return; + } + } else { + return; + } + + // read rpcHeader + if (!rpcHeaderRead) { + // Every connection is expected to send the header. + if (rpcHeaderBuffer == null) { + rpcHeaderBuffer = ByteBuffer.allocate(2); + } + byteBuf.readBytes(rpcHeaderBuffer); + if (!rpcHeaderBuffer.hasArray() || rpcHeaderBuffer.remaining() > 0) { + return; + } + int version = rpcHeaderBuffer.get(0); + byte[] method = new byte[] { rpcHeaderBuffer.get(1) }; + try { + authMethod = AuthMethod.read(new DataInputStream( + new ByteArrayInputStream(method))); + dataLengthBuffer.flip(); + } catch (IOException ioe) { + errorClass = ioe.getClass().getName(); + error = StringUtils.stringifyException(ioe); + } + + if (!HEADER.equals(dataLengthBuffer) || version != CURRENT_VERSION) { + return; + } + dataLengthBuffer.clear(); + if (authMethod == null) { + throw new RuntimeException("Unable to read authentication method"); + } + rpcHeaderBuffer = null; + rpcHeaderRead = true; + continue; + } + + // read data length and allocate buffer; + if (data == null) { + dataLengthBuffer.flip(); + int dataLength = dataLengthBuffer.getInt(); + if (dataLength < 0) { + LOG.warn("Unexpected data length " + dataLength + "!! from " + + address.getHostName()); + } + data = ByteBuffer.allocate(dataLength); + } + + // read received data + byteBuf.readBytes(data); + if (data.remaining() == 0) { + dataLengthBuffer.clear(); + data.flip(); + boolean isHeaderRead = headerRead; + call = processOneRpc(data.array()); + data = null; + if (!isHeaderRead) { + continue; + } + } + } catch (OutOfMemoryError e) { + // we can run out of memory if we have too many threads + // log the event and sleep for a minute and give + // some thread(s) a chance to finish + // + LOG.warn("Out of Memory in server select", e); + try { + Thread.sleep(60000); + } catch (Exception ie) { + } + } catch (Exception e) { + LOG.warn("Exception in Responder " + + StringUtils.stringifyException(e)); + break; + } + sendResponse(ctx, call); + } + ReferenceCountUtil.release(msg); + } + + /** + * Send response data to client + * + * @param ctx + * @param call + */ + private void sendResponse(ChannelHandlerContext ctx, Call call) { + ByteArrayOutputStream buf = new ByteArrayOutputStream( + INITIAL_RESP_BUF_SIZE); + Writable value = null; + try { + value = call(protocol, call.param, call.timestamp); + } catch (Throwable e) { + String logMsg = this.getClass().getName() + ", call " + call + + ": error: " + e; + if (e instanceof RuntimeException || e instanceof Error) { + // These exception types indicate something is probably wrong + // on the server side, as opposed to just a normal exceptional + // result. + LOG.warn(logMsg, e); + } else if (exceptionsHandler.isTerse(e.getClass())) { + // Don't log the whole stack trace of these exceptions. + // Way too noisy! + LOG.info(logMsg); + } else { + LOG.info(logMsg, e); + } + errorClass = e.getClass().getName(); + error = StringUtils.stringifyException(e); + } + try { + setupResponse(buf, call, (error == null) ? Status.SUCCESS + : Status.ERROR, value, errorClass, error); + if (buf.size() > maxRespSize) { + LOG.warn("Large response size " + buf.size() + " for call " + + call.toString()); + buf = new ByteArrayOutputStream(INITIAL_RESP_BUF_SIZE); + } + //send response data; + channelWrite(ctx, call.response); + } catch (Exception e) { + LOG.info(this.getClass().getName() + " caught: " + + StringUtils.stringifyException(e)); + error = null; + } finally { + IOUtils.closeStream(buf); + } + } + + /** + * read header or data + * @param buf + * @return + */ + private Call processOneRpc(byte[] buf) throws IOException { + if (headerRead) { + return processData(buf); + } else { + processHeader(buf); + headerRead = true; + return null; + } + } + + /** + * Reads the connection header following version + * + * @param buf buffer + */ + private void processHeader(byte[] buf) { + DataInputStream in = new DataInputStream(new ByteArrayInputStream(buf)); + try { + header.readFields(in); + String protocolClassName = header.getProtocol(); + if (protocolClassName != null) { + protocol = getProtocolClass(header.getProtocol(), conf); + } + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + IOUtils.closeStream(in); + } + + UserGroupInformation protocolUser = header.getUgi(); + user = protocolUser; + } + + /** + * + * Reads the received data, create call object; + * @param buf buffer to serialize the response into + * @return the IPC Call + */ + private Call processData(byte[] buf) { + DataInputStream dis = new DataInputStream(new ByteArrayInputStream(buf)); + try { + int id = dis.readInt(); // try to read an id + + if (LOG.isDebugEnabled()) + LOG.debug(" got #" + id); + + Writable param = ReflectionUtils.newInstance(paramClass, conf); + param.readFields(dis); // try to read param data + + Call call = new Call(id, param, this); + + return call; + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + IOUtils.closeStream(dis); + } + } + } + + /** + * Setup response for the IPC Call. + * + * @param response buffer to serialize the response into + * @param call {@link Call} to which we are setting up the response + * @param status {@link Status} of the IPC call + * @param rv return value for the IPC Call, if the call was successful + * @param errorClass error class, if the the call failed + * @param error error message, if the call failed + * @throws IOException + */ + private void setupResponse(ByteArrayOutputStream response, Call call, + Status status, Writable rv, String errorClass, String error) + throws IOException { + response.reset(); + DataOutputStream out = new DataOutputStream(response); + out.writeInt(call.id); // write call id + out.writeInt(status.state); // write status + + if (status == Status.SUCCESS) { + rv.write(out); + } else { + WritableUtils.writeString(out, errorClass); + WritableUtils.writeString(out, error); + } + call.setResponse(ByteBuffer.wrap(response.toByteArray())); + } + + /** + * This is a wrapper around {@link WritableByteChannel#write(ByteBuffer)}. If + * the amount of data is large, it writes to channel in smaller chunks. This + * is to avoid jdk from creating many direct buffers as the size of buffer + * increases. This also minimizes extra copies in NIO layer as a result of + * multiple write operations required to write a large buffer. + * + * @see WritableByteChannel#write(ByteBuffer) + * + * @param ctx + * @param buffer + */ + private void channelWrite(ChannelHandlerContext ctx, ByteBuffer buffer) { + try { + ByteBuf buf = ctx.alloc().buffer(); + buf.writeBytes(buffer.array()); + ctx.writeAndFlush(buf); + } catch (Throwable e) { + e.printStackTrace(); + } + } + + + /** A call queued for handling. */ + private static class Call { + private int id; // the client's call id + private Writable param; // the parameter passed + private ChannelInboundHandlerAdapter connection; // connection to client + private long timestamp; // the time received when response is null + // the time served when response is not null + private ByteBuffer response; // the response for this call + + /** + * + * @param id + * @param param + * @param connection + */ + public Call(int id, Writable param, ChannelInboundHandlerAdapter connection) { + this.id = id; + this.param = param; + this.connection = connection; + this.timestamp = System.currentTimeMillis(); + this.response = null; + } + + /** + * + */ + @Override + public String toString() { + return param.toString() + " from " + connection.toString(); + } + + /** + * + * @param response + */ + public void setResponse(ByteBuffer response) { + this.response = response; + } + } + + /** + * ExceptionsHandler manages Exception groups for special handling e.g., terse + * exception group for concise logging messages + */ + static class ExceptionsHandler { + private volatile Set terseExceptions = new HashSet(); + + /** + * Add exception class so server won't log its stack trace. Modifying the + * terseException through this method is thread safe. + * + * @param exceptionClass exception classes + */ + void addTerseExceptions(Class... exceptionClass) { + + // Make a copy of terseException for performing modification + final HashSet newSet = new HashSet(terseExceptions); + + // Add all class names into the HashSet + for (Class name : exceptionClass) { + newSet.add(name.toString()); + } + // Replace terseException set + terseExceptions = Collections.unmodifiableSet(newSet); + } + + /** + * + * @param t + * @return + */ + boolean isTerse(Class t) { + return terseExceptions.contains(t.toString()); + } + } + + /** + * Called for each call. + * + * @param protocol + * @param param + * @param receiveTime + * @return Writable + * @throws IOException + */ + public abstract Writable call(Class protocol, Writable param, + long receiveTime) throws IOException; +} Index: core/src/main/java/org/apache/hama/ipc/AsyncClient.java =================================================================== --- core/src/main/java/org/apache/hama/ipc/AsyncClient.java (리비전 0) +++ core/src/main/java/org/apache/hama/ipc/AsyncClient.java (리비전 0) @@ -0,0 +1,1197 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.ipc; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.logging.LogLevel; +import io.netty.handler.logging.LoggingHandler; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.timeout.IdleState; +import io.netty.handler.timeout.IdleStateEvent; +import io.netty.handler.timeout.IdleStateHandler; + +import java.io.DataInputStream; +import java.io.IOException; +import java.net.ConnectException; +import java.net.InetSocketAddress; +import java.net.SocketTimeoutException; +import java.net.UnknownHostException; +import java.util.Hashtable; +import java.util.Iterator; +import java.util.Map.Entry; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import javax.net.SocketFactory; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.DataOutputBuffer; +import org.apache.hadoop.io.IOUtils; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.WritableUtils; +import org.apache.hadoop.security.SaslRpcServer.AuthMethod; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.util.ReflectionUtils; +import org.apache.hama.util.BSPNetUtils; + +/** + * A client for an IPC service using netty. IPC calls take a single + * {@link Writable} as a parameter, and return a {@link Writable} as their + * value. A service runs on a port and is defined by a parameter class and a + * value class. + * + * @see AsyncClient + */ +public class AsyncClient { + private static final String IPC_CLIENT_CONNECT_MAX_RETRIES_KEY = "ipc.client.connect.max.retries"; + private static final int IPC_CLIENT_CONNECT_MAX_RETRIES_DEFAULT = 10; + private static final Log LOG = LogFactory.getLog(AsyncClient.class); + private Hashtable connections = new Hashtable(); + + private Class valueClass; // class of call values + private int counter = 0; // counter for call ids + private AtomicBoolean running = new AtomicBoolean(true); // if client runs + final private Configuration conf; // configuration obj + + private SocketFactory socketFactory; // only use in order to meet the + // consistency with other clients + private int refCount = 1; + + final private static String PING_INTERVAL_NAME = "ipc.ping.interval"; + final static int DEFAULT_PING_INTERVAL = 60000; // 1 min + + /** + * set the ping interval value in configuration + * + * @param conf Configuration + * @param pingInterval the ping interval + */ + final public static void setPingInterval(Configuration conf, int pingInterval) { + conf.setInt(PING_INTERVAL_NAME, pingInterval); + } + + /** + * Get the ping interval from configuration; If not set in the configuration, + * return the default value. + * + * @param conf Configuration + * @return the ping interval + */ + final static int getPingInterval(Configuration conf) { + return conf.getInt(PING_INTERVAL_NAME, DEFAULT_PING_INTERVAL); + } + + /** + * The time after which a RPC will timeout. If ping is not enabled (via + * ipc.client.ping), then the timeout value is the same as the pingInterval. + * If ping is enabled, then there is no timeout value. + * + * @param conf Configuration + * @return the timeout period in milliseconds. -1 if no timeout value is set + */ + final public static int getTimeout(Configuration conf) { + if (!conf.getBoolean("ipc.client.ping", true)) { + return getPingInterval(conf); + } + return -1; + } + + /** + * Increment this client's reference count + * + */ + synchronized void incCount() { + refCount++; + } + + /** + * Decrement this client's reference count + * + */ + synchronized void decCount() { + refCount--; + } + + /** + * Return if this client has no reference + * + * @return true if this client has no reference; false otherwise + */ + synchronized boolean isZeroReference() { + return refCount == 0; + } + + /** + * Thread that reads responses and notifies callers. Each connection owns a + * socket connected to a remote address. Calls are multiplexed through this + * socket: responses may be delivered out of order. + */ + private class Connection { + private InetSocketAddress serverAddress; // server ip:port + private ConnectionHeader header; // connection header + private final ConnectionId remoteId; // connection id + private AuthMethod authMethod; // authentication method + + private boolean isSecurityEnabled; + + private EventLoopGroup group; + private Bootstrap bootstrap; + private Channel channel; + private SslContext sslCtx; + private int rpcTimeout; + private int maxIdleTime; // connections will be culled if it was idle + + private final RetryPolicy connectionRetryPolicy; + private boolean tcpNoDelay; // if T then disable Nagle's Algorithm + private int pingInterval; // how often sends ping to the server in msecs + + // currently active calls + private Hashtable calls = new Hashtable(); + private AtomicBoolean shouldCloseConnection = new AtomicBoolean(); // indicate + private IOException closeException; // if the connection is closed, close + // reason + + /** + * Setup Connection Configuration + * + * @param remoteId remote connection Id + * @throws IOException + */ + public Connection(ConnectionId remoteId) throws IOException { + group = new NioEventLoopGroup(); + bootstrap = new Bootstrap(); + this.remoteId = remoteId; + this.serverAddress = remoteId.getAddress(); + if (serverAddress.isUnresolved()) { + throw new UnknownHostException("unknown host: " + + remoteId.getAddress().getHostName()); + } + this.maxIdleTime = remoteId.getMaxIdleTime(); + this.connectionRetryPolicy = remoteId.connectionRetryPolicy; + this.tcpNoDelay = remoteId.getTcpNoDelay(); + this.pingInterval = remoteId.getPingInterval(); + if (LOG.isDebugEnabled()) { + LOG.debug("The ping interval is" + this.pingInterval + "ms."); + } + this.rpcTimeout = remoteId.getRpcTimeout(); + UserGroupInformation ticket = remoteId.getTicket(); + Class protocol = remoteId.getProtocol(); + + authMethod = AuthMethod.SIMPLE; + + header = new ConnectionHeader(protocol == null ? null + : protocol.getName(), ticket, authMethod); + + if (LOG.isDebugEnabled()) + LOG.debug("Use " + authMethod + " authentication"); + + } + + /** + * Add a call to this connection's call queue and notify a listener; + * synchronized. Returns false if called during shutdown. + * + * @param call to add + * @return true if the call was added. + */ + private synchronized boolean addCall(Call call) { + if (shouldCloseConnection.get()) + return false; + calls.put(call.id, call); + notify(); // is need? + return true; + } + + /** + * Update the server address if the address corresponding to the host name + * has changed. + */ + private synchronized boolean updateAddress() throws IOException { + // Do a fresh lookup with the old host name. + InetSocketAddress currentAddr = BSPNetUtils.makeSocketAddr( + serverAddress.getHostName(), serverAddress.getPort()); + + if (!serverAddress.equals(currentAddr)) { + LOG.warn("Address change detected. Old: " + serverAddress.toString() + + " New: " + currentAddr.toString()); + serverAddress = currentAddr; + return true; + } + return false; + } + + /** + * Connect to the server and set up the I/O streams. It then sends a header + * to the server. + */ + private synchronized void setupIOstreams() throws InterruptedException { + if (channel != null && channel.isActive()) { + return; + } + + try { + if (LOG.isDebugEnabled()) { + LOG.debug("Connecting to " + serverAddress); + } + + setupConnection(); + writeRpcHeader(channel); + writeHeader(channel); + } catch (Throwable t) { + if (t instanceof IOException) { + markClosed((IOException) t); + } else { + markClosed(new IOException("Couldn't set up IO streams", t)); + } + close(); + } + } + + /** + * Configure the client and connect to server + */ + private void setupConnection() throws Exception { + while (true) { + short ioFailures = 0; + short timeoutFailures = 0; + try { + isSecurityEnabled = UserGroupInformation.isSecurityEnabled(); + + if (isSecurityEnabled) { + sslCtx = SslContext + .newClientContext(InsecureTrustManagerFactory.INSTANCE); + } else { + sslCtx = null; + } + + // rpcTimeout overwrites pingInterval + if (rpcTimeout > 0) { + pingInterval = rpcTimeout; + } + + // Configure the client. + // NioEventLoopGroup is a multithreaded event loop that handles I/O + // operation + group = new NioEventLoopGroup(); + + // Bootstrap is a helper class that sets up a client + bootstrap = new Bootstrap(); + bootstrap.group(group).channel(NioSocketChannel.class) + .option(ChannelOption.TCP_NODELAY, tcpNoDelay) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, pingInterval) + .handler(new LoggingHandler(LogLevel.INFO)) + .handler(new ChannelInitializer() { + @Override + public void initChannel(SocketChannel ch) throws Exception { + ChannelPipeline p = ch.pipeline(); + if (sslCtx != null) { + p.addLast(sslCtx.newHandler(ch.alloc(), + serverAddress.getHostName(), serverAddress.getPort())); + } + p.addLast(new IdleStateHandler(0, 0, maxIdleTime)); + // Register message processing handler + p.addLast(new NioClientInboundHandler()); + } + }); + + // Bind and start to accept incoming connections. + ChannelFuture channelFuture = bootstrap.connect( + serverAddress.getAddress(), serverAddress.getPort()).sync(); + + // Get io channel + channel = channelFuture.channel(); + break; + } catch (IOException toe) { + /* + * Check for an address change and update the local reference. Reset + * the failure counter if the address was changed + */ + if (updateAddress()) { + timeoutFailures = ioFailures = 0; + /* + * The max number of retries is 45, which amounts to 20s*45 = 15 + * minutes retries. + */ + handleConnectionFailure(timeoutFailures++, 45, toe); + } + } catch (Exception ie) { + if (updateAddress()) { + timeoutFailures = ioFailures = 0; + } + handleConnectionFailure(ioFailures++, ie); + } + } + } + + /** + * Write the rpc protocol header for each connection Out is not synchronized + * because only the first thread does this. + * + * @param channel + */ + private void writeRpcHeader(Channel channel) { + DataOutputBuffer buff = null; + try { + buff = new DataOutputBuffer(); + authMethod.write(buff); + ByteBuf buf = channel.alloc().buffer(); + buf.writeBytes(Server.HEADER.array()); + buf.writeByte(Server.CURRENT_VERSION); + buf.writeByte(buff.getData()[0]); + channel.write(buf); + } catch (Exception e) { + LOG.error("Couldn't send rpcheader" + e); + } finally { + IOUtils.closeStream(buff); + } + } + + /** + * Write the protocol header for each connection Out is not synchronized + * because only the first thread does this. + * + * @param channel + */ + private void writeHeader(Channel channel) { + DataOutputBuffer buff = null; + try { + buff = new DataOutputBuffer(); + header.write(buff); + byte[] data = buff.getData(); + int dataLength = buff.getLength(); + + ByteBuf buf1 = channel.alloc().buffer(); + buf1.writeInt(dataLength); + buf1.writeBytes(data, 0, dataLength); + channel.write(buf1); + } catch (Exception ioe) { + LOG.error("Couldn't send header"); + } finally { + IOUtils.closeStream(buff); + } + } + + /** + * close the current connection gracefully. + */ + private void closeConnection() { + try { + if (!this.group.isTerminated()) { + this.group.shutdownGracefully(); + LOG.info("client gracefully shutdown"); + } + } catch (Exception e) { + LOG.warn("Not able to close a client", e); + } + } + + /** + * This class process received response message from server. + */ + private class NioClientInboundHandler extends ChannelInboundHandlerAdapter { + + /** + * Receive a response. This method is called with the received response + * message, whenever new data is received from a server. + * + * @param ctx + * @param cause + */ + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + ByteBuf byteBuf = (ByteBuf) msg; + ByteBufInputStream byteBufInputStream = new ByteBufInputStream(byteBuf); + DataInputStream in = new DataInputStream(byteBufInputStream); + while (true) { + try { + if (in.available() <= 0) + break; + // try to read an id + int id = in.readInt(); + + if (LOG.isDebugEnabled()) + LOG.debug(serverAddress.getHostName() + " got value #" + id); + + Call call = calls.get(id); + + // read call status + int state = in.readInt(); + if (state == Status.SUCCESS.state) { + Writable value = ReflectionUtils.newInstance(valueClass, conf); + value.readFields(in); // read value + call.setValue(value); + calls.remove(id); + } else if (state == Status.ERROR.state) { + String className = WritableUtils.readString(in); + byte[] errorBytes = new byte[in.available()]; + in.readFully(errorBytes); + call.setException(new RemoteException(className, new String( + errorBytes))); + calls.remove(id); + } else if (state == Status.FATAL.state) { + // Close the connection + markClosed(new RemoteException(WritableUtils.readString(in), + WritableUtils.readString(in))); + } else { + byte[] garbageBytes = new byte[in.available()]; + in.readFully(garbageBytes); + } + } catch (IOException e) { + markClosed(e); + } + } + ((ByteBuf) msg).release(); + } + + /** + * Ths event handler method is called with a Throwable due to an I/O + * error. Then, exception is logged and its associated channel is closed + * here + * + * @param ctx + * @param cause + */ + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + LOG.error("Occured I/O Error : " + cause.getMessage()); + ctx.close(); + } + + /** + * this method is triggered after a long reading/writing/idle time, it is + * marked as to be closed, or the client is marked as not running. + * + * @param ctx + * @param evt + */ + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof IdleStateEvent) { + IdleStateEvent e = (IdleStateEvent) evt; + if (e.state() != IdleState.ALL_IDLE) { + if (!calls.isEmpty() && !shouldCloseConnection.get() + && running.get()) { + return; + } else if (shouldCloseConnection.get()) { + markClosed(null); + } else if (calls.isEmpty()) { // idle connection closed or stopped + markClosed(null); + } else { // get stopped but there are still pending requests + markClosed((IOException) new IOException() + .initCause(new InterruptedException())); + } + closeConnection(); + } + } + + } + + } + + /** + * Handle connection failures If the current number of retries is equal to + * the max number of retries, stop retrying and throw the exception; + * Otherwise backoff 1 second and try connecting again. This Method is only + * called from inside setupIOstreams(), which is synchronized. Hence the + * sleep is synchronized; the locks will be retained. + * + * @param curRetries current number of retries + * @param maxRetries max number of retries allowed + * @param ioe failure reason + * @throws IOException if max number of retries is reached + */ + private void handleConnectionFailure(int curRetries, int maxRetries, + IOException ioe) throws IOException { + + closeConnection(); + + // throw the exception if the maximum number of retries is reached + if (curRetries >= maxRetries) { + throw ioe; + } + + // otherwise back off and retry + try { + Thread.sleep(1000); + } catch (InterruptedException ignored) { + } + + LOG.info("Retrying connect to server: " + serverAddress + + ". Already tried " + curRetries + " time(s); maxRetries=" + + maxRetries); + } + + /* + * Handle connection failures If the current number of retries, stop + * retrying and throw the exception; Otherwise backoff 1 second and try + * connecting again. This Method is only called from inside + * setupIOstreams(), which is synchronized. Hence the sleep is synchronized; + * the locks will be retained. + * @param curRetries current number of retries + * @param ioe failure reason + * @throws Exception if max number of retries is reached + */ + private void handleConnectionFailure(int curRetries, Exception ioe) + throws Exception { + closeConnection(); + + final boolean retry; + try { + retry = connectionRetryPolicy.shouldRetry(ioe, curRetries); + } catch (Exception e) { + throw e instanceof IOException ? (IOException) e : new IOException(e); + } + if (!retry) { + throw ioe; + } + + LOG.info("Retrying connect to server: " + serverAddress + + ". Already tried " + curRetries + " time(s); retry policy is " + + connectionRetryPolicy); + } + + /** + * Return the remote address of server + * + * @return remote server address + */ + public InetSocketAddress getRemoteAddress() { + return serverAddress; + } + + /** + * Initiates a call by sending the parameter to the remote server. + * + * @param sendCall + */ + public void sendParam(Call sendCall) { + if (LOG.isDebugEnabled()) + LOG.debug(this.getClass().getName() + " sending #" + sendCall.id); + DataOutputBuffer buff = null; + try { + synchronized (channel) { + buff = new DataOutputBuffer(); + buff.writeInt(sendCall.id); + sendCall.param.write(buff); + byte[] data = buff.getData(); + int dataLength = buff.getLength(); + + ByteBuf buf = channel.alloc().buffer(); + + buf.writeInt(dataLength); + buf.writeBytes(data, 0, dataLength); + ChannelFuture channelFuture = channel.writeAndFlush(buf); + + if (channelFuture.cause() != null) { + throw channelFuture.cause(); + } + } + } catch (IOException ioe) { + markClosed(ioe); + } catch (Throwable t) { + markClosed(new IOException(t)); + } finally { + // the buffer is just an in-memory buffer, but it is still + // polite to close early + IOUtils.closeStream(buff); + } + } + + /** + * Mark the connection to be closed + * + * @param ioe + **/ + private synchronized void markClosed(IOException ioe) { + if (shouldCloseConnection.compareAndSet(false, true)) { + closeException = ioe; + notifyAll(); + } + } + + /** Close the connection. */ + private synchronized void close() { + if (!shouldCloseConnection.get()) { + LOG.error("The connection is not in the closed state"); + return; + } + + // release the resources + // first thing to do;take the connection out of the connection list + synchronized (connections) { + if (connections.get(remoteId) == this) { + Connection connection = connections.remove(remoteId); + connection.closeConnection(); + } + } + + // clean up all calls + if (closeException == null) { + if (!calls.isEmpty()) { + LOG.warn("A connection is closed for no cause and calls are not empty"); + + // clean up calls anyway + closeException = new IOException("Unexpected closed connection"); + cleanupCalls(); + } + } else { + // log the info + if (LOG.isDebugEnabled()) { + LOG.debug("closing ipc connection to " + serverAddress + ": " + + closeException.getMessage(), closeException); + } + + // cleanup calls + cleanupCalls(); + } + if (LOG.isDebugEnabled()) + LOG.debug(serverAddress.getHostName() + ": closed"); + } + + /** Cleanup all calls and mark them as done */ + private void cleanupCalls() { + Iterator> itor = calls.entrySet().iterator(); + while (itor.hasNext()) { + Call c = itor.next().getValue(); + c.setException(closeException); // local exception + itor.remove(); + } + } + } + + /** A call waiting for a value. */ + private class Call { + int id; // call id + Writable param; // parameter + Writable value; // value, null if error + IOException error; // exception, null if value + boolean done; // true when call is done + + protected Call(Writable param) { + this.param = param; + synchronized (AsyncClient.this) { + this.id = counter++; + } + } + + /** + * Indicate when the call is complete and the value or error are available. + * Notifies by default. + */ + protected synchronized void callComplete() { + this.done = true; + notify(); // notify caller + } + + /** + * Set the exception when there is an error. Notify the caller the call is + * done. + * + * @param error exception thrown by the call; either local or remote + */ + public synchronized void setException(IOException error) { + this.error = error; + this.callComplete(); + } + + /** + * Set the return value when there is no error. Notify the caller the call + * is done. + * + * @param value return value of the call. + */ + public synchronized void setValue(Writable value) { + this.value = value; + callComplete(); + } + } + + /** Call implementation used for parallel calls. */ + private class ParallelCall extends Call { + private ParallelResults results; + private int index; + + public ParallelCall(Writable param, ParallelResults results, int index) { + super(param); + this.results = results; + this.index = index; + } + + @Override + /** Deliver result to result collector. */ + protected void callComplete() { + results.callComplete(this); + } + } + + /** Result collector for parallel calls. */ + private static class ParallelResults { + private Writable[] values; + private int size; + private int count; + + public ParallelResults(int size) { + this.values = new Writable[size]; + this.size = size; + } + + /** + * Collect a result. + * + * @param call + */ + public synchronized void callComplete(ParallelCall call) { + values[call.index] = call.value; // store the value + count++; // count it + if (count == size) // if all values are in + notify(); // then notify waiting caller + } + } + + /** + * Construct an IPC client whose values are of the given {@link Writable} + * class. + * + * @param valueClass + * @param conf + * @param factory + */ + public AsyncClient(Class valueClass, Configuration conf, + SocketFactory factory) { + this.valueClass = valueClass; + this.conf = conf; + // SocketFactory only use in order to meet the consistency with other + // clients + this.socketFactory = factory; + } + + /** + * Construct an IPC client with the default SocketFactory + * + * @param valueClass + * @param conf + */ + public AsyncClient(Class valueClass, Configuration conf) { + // SocketFactory only use in order to meet the consistency with other + // clients + this(valueClass, conf, BSPNetUtils.getDefaultSocketFactory(conf)); + } + + /** + * Return the socket factory of this client + * + * @return this client's socket factory + */ + SocketFactory getSocketFactory() { + // SocketFactory only use in order to meet the consistency with other + // clients + return socketFactory; + } + + /** + * Stop all threads related to this client. No further calls may be made using + * this client. + */ + public void stop() { + if (LOG.isDebugEnabled()) { + LOG.debug("Stopping client"); + } + + if (!running.compareAndSet(true, false)) { + return; + } + + // wake up all connections + synchronized (connections) { + for (Connection conn : connections.values()) { + conn.closeConnection(); + } + } + } + + /** + * Make a call, passing param, to the IPC server running at + * address which is servicing the protocol protocol, + * with the ticket credentials, rpcTimeout as + * timeout and conf as configuration for this connection, + * returning the value. Throws exceptions if there are network problems or if + * the remote code threw an exception. + * + * @param param + * @param addr + * @param protocol + * @param ticket + * @param rpcTimeout + * @param conf + * @return Response Writable value + * @throws InterruptedException + * @throws IOException + */ + public Writable call(Writable param, InetSocketAddress addr, + Class protocol, UserGroupInformation ticket, int rpcTimeout, + Configuration conf) throws InterruptedException, IOException { + ConnectionId remoteId = ConnectionId.getConnectionId(addr, protocol, + ticket, rpcTimeout, conf); + return call(param, remoteId); + } + + /** + * Make a call, passing param, to the IPC server defined by + * remoteId, returning the value. Throws exceptions if there are + * network problems or if the remote code threw an exception. + * + * @param param + * @param remoteId + * @return Response Writable value + * @throws InterruptedException + * @throws IOException + */ + public Writable call(Writable param, ConnectionId remoteId) + throws InterruptedException, IOException { + Call call = new Call(param); + + Connection connection = getConnection(remoteId, call); + + connection.sendParam(call); // send the parameter + boolean interrupted = false; + + synchronized (call) { + while (!call.done) { + try { + call.wait(); // wait for the result + } catch (InterruptedException ie) { + interrupted = true; + } + } + + if (interrupted) { + // set the interrupt flag now that we are done waiting + Thread.currentThread().interrupt(); + + } + + if (call.error != null) { + if (call.error instanceof RemoteException) { + call.error.fillInStackTrace(); + throw call.error; + } else { // local exception + // use the connection because it will reflect an ip change, + // unlike + // the remoteId + throw wrapException(connection.getRemoteAddress(), call.error); + } + } else { + return call.value; + } + } + } + + /** + * Take an IOException and the address we were trying to connect to and return + * an IOException with the input exception as the cause. The new exception + * provides the stack trace of the place where the exception is thrown and + * some extra diagnostics information. If the exception is ConnectException or + * SocketTimeoutException, return a new one of the same type; Otherwise return + * an IOException. + * + * @param addr target address + * @param exception the relevant exception + * @return an exception to throw + */ + private IOException wrapException(InetSocketAddress addr, + IOException exception) { + if (exception instanceof ConnectException) { + // connection refused; include the host:port in the error + return (ConnectException) new ConnectException("Call to " + addr + + " failed on connection exception: " + exception) + .initCause(exception); + } else if (exception instanceof SocketTimeoutException) { + return (SocketTimeoutException) new SocketTimeoutException("Call to " + + addr + " failed on socket timeout exception: " + exception) + .initCause(exception); + } else { + return (IOException) new IOException("Call to " + addr + + " failed on local exception: " + exception).initCause(exception); + + } + } + + /** + * Makes a set of calls in parallel. Each parameter is sent to the + * corresponding address. When all values are available, or have timed out or + * errored, the collected results are returned in an array. The array contains + * nulls for calls that timed out or errored. + * + * @param params + * @param addresses + * @param protocol + * @param ticket + * @param conf + * @return Response Writable value array + * @throws IOException + * @throws InterruptedException + */ + public Writable[] call(Writable[] params, InetSocketAddress[] addresses, + Class protocol, UserGroupInformation ticket, Configuration conf) + throws IOException, InterruptedException { + if (addresses.length == 0) + return new Writable[0]; + + ParallelResults results = new ParallelResults(params.length); + ConnectionId remoteId[] = new ConnectionId[addresses.length]; + synchronized (results) { + for (int i = 0; i < params.length; i++) { + ParallelCall call = new ParallelCall(params[i], results, i); + try { + remoteId[i] = ConnectionId.getConnectionId(addresses[i], protocol, + ticket, 0, conf); + Connection connection = getConnection(remoteId[i], call); + connection.sendParam(call); // send each parameter + } catch (IOException e) { + // log errors + LOG.info("Calling " + addresses[i] + " caught: " + e.getMessage(), e); + results.size--; // wait for one fewer result + } + } + + while (results.count != results.size) { + try { + results.wait(); // wait for all results + } catch (InterruptedException e) { + } + } + + return results.values; + } + } + + // for unit testing only + Set getConnectionIds() { + synchronized (connections) { + return connections.keySet(); + } + } + + /** + * Get a connection from the pool, or create a new one and add it to the pool. + * Connections to a given ConnectionId are reused. + * + * @param remoteId + * @param call + * @return connection + * @throws IOException + * @throws InterruptedException + */ + private synchronized Connection getConnection(ConnectionId remoteId, Call call) + throws IOException, InterruptedException { + if (!running.get()) { + // the client is stopped + throw new IOException("The client is stopped"); + } + Connection connection; + /* + * we could avoid this allocation for each RPC by having a connectionsId + * object and with set() method. We need to manage the refs for keys in + * HashMap properly. For now its ok. + */ + do { + // synchronized (connections) { + connection = connections.get(remoteId); + if (connection == null) { + connection = new Connection(remoteId); + connections.put(remoteId, connection); + } else if (!connection.channel.isWritable() + || !connection.channel.isActive()) { + connection = new Connection(remoteId); + connections.remove(remoteId); + connections.put(remoteId, connection); + } + // } + } while (!connection.addCall(call)); + // we don't invoke the method below inside "synchronized (connections)" + // block above. The reason for that is if the server happens to be slow, + // it will take longer to establish a connection and that will slow the + // entire system down. + + connection.setupIOstreams(); + return connection; + } + + /** + * This class holds the address and the user ticket. The client connections to + * servers are uniquely identified by + */ + static class ConnectionId { + InetSocketAddress address; + UserGroupInformation ticket; + Class protocol; + private static final int PRIME = 16777619; + private int rpcTimeout; + private String serverPrincipal; + private int maxIdleTime; // connections will be culled if it was idle for + // maxIdleTime msecs + private final RetryPolicy connectionRetryPolicy; + private boolean tcpNoDelay; // if T then disable Nagle's Algorithm + private int pingInterval; // how often sends ping to the server in msecs + + ConnectionId(InetSocketAddress address, Class protocol, + UserGroupInformation ticket, int rpcTimeout, String serverPrincipal, + int maxIdleTime, RetryPolicy connectionRetryPolicy, boolean tcpNoDelay, + int pingInterval) { + this.protocol = protocol; + this.address = address; + this.ticket = ticket; + this.rpcTimeout = rpcTimeout; + this.serverPrincipal = serverPrincipal; + this.maxIdleTime = maxIdleTime; + this.connectionRetryPolicy = connectionRetryPolicy; + this.tcpNoDelay = tcpNoDelay; + this.pingInterval = pingInterval; + } + + InetSocketAddress getAddress() { + return address; + } + + Class getProtocol() { + return protocol; + } + + UserGroupInformation getTicket() { + return ticket; + } + + private int getRpcTimeout() { + return rpcTimeout; + } + + String getServerPrincipal() { + return serverPrincipal; + } + + int getMaxIdleTime() { + return maxIdleTime; + } + + boolean getTcpNoDelay() { + return tcpNoDelay; + } + + int getPingInterval() { + return pingInterval; + } + + static ConnectionId getConnectionId(InetSocketAddress addr, + Class protocol, UserGroupInformation ticket, Configuration conf) + throws IOException { + return getConnectionId(addr, protocol, ticket, 0, conf); + } + + static ConnectionId getConnectionId(InetSocketAddress addr, + Class protocol, UserGroupInformation ticket, int rpcTimeout, + Configuration conf) throws IOException { + return getConnectionId(addr, protocol, ticket, rpcTimeout, null, conf); + } + + static ConnectionId getConnectionId(InetSocketAddress addr, + Class protocol, UserGroupInformation ticket, int rpcTimeout, + RetryPolicy connectionRetryPolicy, Configuration conf) + throws IOException { + + if (connectionRetryPolicy == null) { + final int max = conf.getInt(IPC_CLIENT_CONNECT_MAX_RETRIES_KEY, + IPC_CLIENT_CONNECT_MAX_RETRIES_DEFAULT); + connectionRetryPolicy = RetryPolicies + .retryUpToMaximumCountWithFixedSleep(max, 1, TimeUnit.SECONDS); + } + + String remotePrincipal = getRemotePrincipal(conf, addr, protocol); + return new ConnectionId(addr, protocol, ticket, + rpcTimeout, + remotePrincipal, + conf.getInt("ipc.client.connection.maxidletime", 10000), // 10s + connectionRetryPolicy, + conf.getBoolean("ipc.client.tcpnodelay", false), + AsyncClient.getPingInterval(conf)); + } + + private static String getRemotePrincipal(Configuration conf, + InetSocketAddress address, Class protocol) throws IOException { + if (!UserGroupInformation.isSecurityEnabled() || protocol == null) { + return null; + } + return null; + } + + static boolean isEqual(Object a, Object b) { + return a == null ? b == null : a.equals(b); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (obj instanceof ConnectionId) { + ConnectionId that = (ConnectionId) obj; + return isEqual(this.address, that.address) + && this.maxIdleTime == that.maxIdleTime + && isEqual(this.connectionRetryPolicy, that.connectionRetryPolicy) + && this.pingInterval == that.pingInterval + && isEqual(this.protocol, that.protocol) + && this.rpcTimeout == that.rpcTimeout + && isEqual(this.serverPrincipal, that.serverPrincipal) + && this.tcpNoDelay == that.tcpNoDelay + && isEqual(this.ticket, that.ticket); + } + return false; + } + + @Override + public int hashCode() { + int result = connectionRetryPolicy.hashCode(); + result = PRIME * result + ((address == null) ? 0 : address.hashCode()); + result = PRIME * result + maxIdleTime; + result = PRIME * result + pingInterval; + result = PRIME * result + ((protocol == null) ? 0 : protocol.hashCode()); + result = PRIME * rpcTimeout; + result = PRIME * result + + ((serverPrincipal == null) ? 0 : serverPrincipal.hashCode()); + result = PRIME * result + (tcpNoDelay ? 1231 : 1237); + result = PRIME * result + ((ticket == null) ? 0 : ticket.hashCode()); + return result; + } + } +} Index: core/src/main/java/org/apache/hama/ipc/AsyncRPC.java =================================================================== --- core/src/main/java/org/apache/hama/ipc/AsyncRPC.java (리비전 0) +++ core/src/main/java/org/apache/hama/ipc/AsyncRPC.java (리비전 0) @@ -0,0 +1,782 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.ipc; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.lang.reflect.Array; +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; +import java.net.ConnectException; +import java.net.InetSocketAddress; +import java.net.SocketTimeoutException; +import java.util.HashMap; +import java.util.Map; + +import javax.net.SocketFactory; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configurable; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.ObjectWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.security.SaslRpcServer; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.security.token.SecretManager; +import org.apache.hadoop.security.token.TokenIdentifier; +import org.apache.hama.util.BSPNetUtils; + +/** + * A simple RPC mechanism using netty. + * + * A protocol is a Java interface. All parameters and return types must + * be one of: + * + *
    + *
  • a primitive type, boolean, byte, + * char, short, int, long, + * float, double, or void; or
  • + * + *
  • a {@link String}; or
  • + * + *
  • a {@link Writable}; or
  • + * + *
  • an array of the above types
  • + *
+ * + * All methods in the protocol should throw only IOException. No field data of + * the protocol instance is transmitted. + */ +public class AsyncRPC { + private static final Log LOG = LogFactory.getLog(AsyncRPC.class); + + private AsyncRPC() { + } // no public ctor + + /** A method invocation, including the method name and its parameters. */ + @SuppressWarnings("rawtypes") + private static class Invocation implements Writable, Configurable { + private String methodName; + private Class[] parameterClasses; + private Object[] parameters; + private Configuration conf; + + @SuppressWarnings("unused") + public Invocation() { + } + + /** + * + * @param method + * @param parameters + */ + public Invocation(Method method, Object[] parameters) { + this.methodName = method.getName(); + this.parameterClasses = method.getParameterTypes(); + this.parameters = parameters; + } + + /** The name of the method invoked. */ + public String getMethodName() { + return methodName; + } + + /** The parameter classes. */ + public Class[] getParameterClasses() { + return parameterClasses; + } + + /** The parameter instances. */ + public Object[] getParameters() { + return parameters; + } + + /** + * + * @param in + */ + public void readFields(DataInput in) throws IOException { + methodName = Text.readString(in); + parameters = new Object[in.readInt()]; + parameterClasses = new Class[parameters.length]; + ObjectWritable objectWritable = new ObjectWritable(); + for (int i = 0; i < parameters.length; i++) { + parameters[i] = ObjectWritable + .readObject(in, objectWritable, this.conf); + parameterClasses[i] = objectWritable.getDeclaredClass(); + } + } + + /** + * + * @param out + */ + public void write(DataOutput out) throws IOException { + Text.writeString(out, methodName); + out.writeInt(parameterClasses.length); + for (int i = 0; i < parameterClasses.length; i++) { + ObjectWritable.writeObject(out, parameters[i], parameterClasses[i], + conf); + } + } + + public String toString() { + StringBuffer buffer = new StringBuffer(); + buffer.append(methodName); + buffer.append("("); + for (int i = 0; i < parameters.length; i++) { + if (i != 0) + buffer.append(", "); + buffer.append(parameters[i]); + } + buffer.append(")"); + return buffer.toString(); + } + + public void setConf(Configuration conf) { + this.conf = conf; + } + + public Configuration getConf() { + return this.conf; + } + + } + + /** Cache a client using its socket factory as the hash key */ + static private class ClientCache { + private Map clients = new HashMap(); + + /** + * Construct & cache an IPC client with the user-provided SocketFactory if + * no cached client exists. + * + * @param conf Configuration + * @return an IPC client + */ + private synchronized AsyncClient getClient(Configuration conf, + SocketFactory factory) { + // Construct & cache client. The configuration is only used for timeout, + // and Clients have connection pools. So we can either (a) lose some + // connection pooling and leak sockets, or (b) use the same timeout for + // all + // configurations. Since the IPC is usually intended globally, not + // per-job, we choose (a). + AsyncClient client = clients.get(factory); + if (client == null) { + client = new AsyncClient(ObjectWritable.class, conf, factory); + clients.put(factory, client); + } else { + client.incCount(); + } + return client; + } + + /** + * Construct & cache an IPC client with the default SocketFactory if no + * cached client exists. + * + * @param conf Configuration + * @return an IPC client + */ + private synchronized AsyncClient getClient(Configuration conf) { + return getClient(conf, SocketFactory.getDefault()); + } + + /** + * Stop a RPCWithNetty client connection A RPCWithNetty client is closed + * only when its reference count becomes zero. + * + * @param client + */ + private void stopClient(AsyncClient client) { + synchronized (this) { + client.decCount(); + if (client.isZeroReference()) { + clients.remove(client.getSocketFactory()); + } + } + if (client.isZeroReference()) { + client.stop(); + } + } + } + + private static ClientCache CLIENTS = new ClientCache(); + + /** + * for unit testing only + * + * @param conf + * @return + */ + static AsyncClient getClient(Configuration conf) { + return CLIENTS.getClient(conf); + } + + /** + * + */ + private static class Invoker implements InvocationHandler { + private AsyncClient.ConnectionId remoteId; + private AsyncClient client; + private boolean isClosed = false; + + private Invoker(Class protocol, + InetSocketAddress address, UserGroupInformation ticket, + Configuration conf, SocketFactory factory, int rpcTimeout, + RetryPolicy connectionRetryPolicy) throws IOException { + this.remoteId = AsyncClient.ConnectionId.getConnectionId(address, protocol, + ticket, rpcTimeout, connectionRetryPolicy, conf); + this.client = CLIENTS.getClient(conf, factory); + } + + public Object invoke(Object proxy, Method method, Object[] args) + throws Throwable { + final boolean logDebug = LOG.isDebugEnabled(); + long startTime = 0; + if (logDebug) { + startTime = System.currentTimeMillis(); + } + + ObjectWritable value = (ObjectWritable) client.call(new Invocation( + method, args), remoteId); + if (logDebug) { + long callTime = System.currentTimeMillis() - startTime; + LOG.debug("Call: " + method.getName() + " " + callTime); + } + return value.get(); + } + + /** close the RPCWithNetty client that's responsible for this invoker's RPCs */ + synchronized private void close() { + if (!isClosed) { + isClosed = true; + CLIENTS.stopClient(client); + } + } + } + + /** + * A version mismatch for the RPC protocol. + */ + @SuppressWarnings("serial") + public static class VersionMismatch extends IOException { + private String interfaceName; + private long clientVersion; + private long serverVersion; + + /** + * Create a version mismatch exception + * + * @param interfaceName the name of the protocol mismatch + * @param clientVersion the client's version of the protocol + * @param serverVersion the server's version of the protocol + */ + public VersionMismatch(String interfaceName, long clientVersion, + long serverVersion) { + super("Protocol " + interfaceName + " version mismatch. (client = " + + clientVersion + ", server = " + serverVersion + ")"); + this.interfaceName = interfaceName; + this.clientVersion = clientVersion; + this.serverVersion = serverVersion; + } + + /** + * Get the interface name + * + * @return the java class name (eg. + * org.apache.hadoop.mapred.InterTrackerProtocol) + */ + public String getInterfaceName() { + return interfaceName; + } + + /** + * Get the client's preferred version + */ + public long getClientVersion() { + return clientVersion; + } + + /** + * Get the server's agreed to version. + */ + public long getServerVersion() { + return serverVersion; + } + } + + /** + * Get a proxy connection to a remote server + */ + public static VersionedProtocol waitForProxy( + Class protocol, long clientVersion, + InetSocketAddress addr, Configuration conf) throws IOException { + return waitForProxy(protocol, clientVersion, addr, conf, 0, Long.MAX_VALUE); + } + + /** + * Get a proxy connection to a remote server + * + * @param protocol protocol class + * @param clientVersion client version + * @param addr remote address + * @param conf configuration to use + * @param connTimeout time in milliseconds before giving up + * @return the proxy + * @throws IOException if the far end through a RemoteException + */ + static VersionedProtocol waitForProxy( + Class protocol, long clientVersion, + InetSocketAddress addr, Configuration conf, long connTimeout) + throws IOException { + return waitForProxy(protocol, clientVersion, addr, conf, 0, connTimeout); + } + + /** + * Get a proxy connection to a remote server + * + * @param protocol protocol class + * @param clientVersion client version + * @param addr remote address + * @param conf configuration to use + * @param rpcTimeout rpc timeout + * @param connTimeout time in milliseconds before giving up + * @return the proxy + * @throws IOException if the far end through a RemoteException + */ + static VersionedProtocol waitForProxy( + Class protocol, long clientVersion, + InetSocketAddress addr, Configuration conf, int rpcTimeout, + long connTimeout) throws IOException { + long startTime = System.currentTimeMillis(); + IOException ioe; + while (true) { + try { + return getProxy(protocol, clientVersion, addr, conf, rpcTimeout); + } catch (ConnectException se) { // namenode has not been started + LOG.info("Server at " + addr + " not available yet, Zzzzz..."); + ioe = se; + } catch (SocketTimeoutException te) { // namenode is busy + LOG.info("Problem connecting to server: " + addr); + ioe = te; + } + // check if timed out + if (System.currentTimeMillis() - connTimeout >= startTime) { + throw ioe; + } + + // wait for retry + try { + Thread.sleep(1000); + } catch (InterruptedException ie) { + // IGNORE + } + } + } + + /** + * Construct a client-side proxy object that implements the named protocol, + * talking to a server at the named address. + * + * @param protocol + * @param clientVersion + * @param addr + * @param conf + * @param factory + * @return the proxy + * @throws IOException + */ + public static VersionedProtocol getProxy( + Class protocol, long clientVersion, + InetSocketAddress addr, Configuration conf, SocketFactory factory) + throws IOException { + UserGroupInformation ugi = UserGroupInformation.getCurrentUser(); + return getProxy(protocol, clientVersion, addr, ugi, conf, factory, 0); + } + + /** + * Construct a client-side proxy object that implements the named protocol, + * talking to a server at the named address. + * + * @param protocol + * @param clientVersion + * @param addr + * @param conf + * @param factory + * @param rpcTimeout + * @return the proxy + * @throws IOException + */ + public static VersionedProtocol getProxy( + Class protocol, long clientVersion, + InetSocketAddress addr, Configuration conf, SocketFactory factory, + int rpcTimeout) throws IOException { + UserGroupInformation ugi = UserGroupInformation.getCurrentUser(); + return getProxy(protocol, clientVersion, addr, ugi, conf, factory, + rpcTimeout); + } + + /** + * Construct a client-side proxy object that implements the named protocol, + * talking to a server at the named address. + * + * @param protocol + * @param clientVersion + * @param addr + * @param ticket + * @param conf + * @param factory + * @return the proxy + * @throws IOException + */ + public static VersionedProtocol getProxy( + Class protocol, long clientVersion, + InetSocketAddress addr, UserGroupInformation ticket, Configuration conf, + SocketFactory factory) throws IOException { + return getProxy(protocol, clientVersion, addr, ticket, conf, factory, 0); + } + + /** + * Construct a client-side proxy object that implements the named protocol, + * talking to a server at the named address. + * + * @param protocol + * @param clientVersion + * @param addr + * @param ticket + * @param conf + * @param factory + * @param rpcTimeout + * @return the proxy + * @throws IOException + */ + public static VersionedProtocol getProxy( + Class protocol, long clientVersion, + InetSocketAddress addr, UserGroupInformation ticket, Configuration conf, + SocketFactory factory, int rpcTimeout) throws IOException { + return getProxy(protocol, clientVersion, addr, ticket, conf, factory, + rpcTimeout, null, true); + } + + /** + * Construct a client-side proxy object that implements the named protocol, + * talking to a server at the named address. + * + * @param protocol + * @param clientVersion + * @param addr + * @param ticket + * @param conf + * @param factory + * @param rpcTimeout + * @param connectionRetryPolicy + * @param checkVersion + * @return the proxy + * @throws IOException + */ + public static VersionedProtocol getProxy( + Class protocol, long clientVersion, + InetSocketAddress addr, UserGroupInformation ticket, Configuration conf, + SocketFactory factory, int rpcTimeout, RetryPolicy connectionRetryPolicy, + boolean checkVersion) throws IOException { + + if (UserGroupInformation.isSecurityEnabled()) { + SaslRpcServer.init(conf); + } + final Invoker invoker = new Invoker(protocol, addr, ticket, conf, factory, + rpcTimeout, connectionRetryPolicy); + VersionedProtocol proxy = (VersionedProtocol) Proxy.newProxyInstance( + protocol.getClassLoader(), new Class[] { protocol }, invoker); + + if (checkVersion) { + checkVersion(protocol, clientVersion, proxy); + } + return proxy; + } + + /** + * Get server version and then compare it with client version. + * + * @param protocol + * @param clientVersion + * @param proxy + * @throws IOException + */ + public static void checkVersion(Class protocol, + long clientVersion, VersionedProtocol proxy) throws IOException { + long serverVersion = proxy.getProtocolVersion(protocol.getName(), + clientVersion); + if (serverVersion != clientVersion) { + throw new VersionMismatch(protocol.getName(), clientVersion, + serverVersion); + } + } + + /** + * Construct a client-side proxy object with the default SocketFactory + * + * @param protocol + * @param clientVersion + * @param addr + * @param conf + * @return a proxy instance + * @throws IOException + */ + public static VersionedProtocol getProxy( + Class protocol, long clientVersion, + InetSocketAddress addr, Configuration conf) throws IOException { + return getProxy(protocol, clientVersion, addr, conf, + BSPNetUtils.getDefaultSocketFactory(conf), 0); + } + + /** + * Get VersionedProtocol + * + * @param protocol + * @param clientVersion + * @param addr + * @param conf + * @param rpcTimeout + * @return the proxy + * @throws IOException + */ + public static VersionedProtocol getProxy( + Class protocol, long clientVersion, + InetSocketAddress addr, Configuration conf, int rpcTimeout) + throws IOException { + + return getProxy(protocol, clientVersion, addr, conf, + BSPNetUtils.getDefaultSocketFactory(conf), rpcTimeout); + } + + /** + * Stop this proxy and release its invoker's resource + * + * @param proxy the proxy to be stopped + */ + public static void stopProxy(VersionedProtocol proxy) { + if (proxy != null) { + ((Invoker) Proxy.getInvocationHandler(proxy)).close(); + } + } + + /** + * Expert: Make multiple, parallel calls to a set of servers. + * + * @param method + * @param params + * @param addrs + * @param ticket + * @param conf + * @return response object array + * @throws IOException + * @throws InterruptedException + */ + public static Object[] call(Method method, Object[][] params, + InetSocketAddress[] addrs, UserGroupInformation ticket, Configuration conf) + throws IOException, InterruptedException { + + Invocation[] invocations = new Invocation[params.length]; + for (int i = 0; i < params.length; i++) + invocations[i] = new Invocation(method, params[i]); + AsyncClient client = CLIENTS.getClient(conf); + try { + Writable[] wrappedValues = client.call(invocations, addrs, + method.getDeclaringClass(), ticket, conf); + + if (method.getReturnType() == Void.TYPE) { + return null; + } + + Object[] values = (Object[]) Array.newInstance(method.getReturnType(), + wrappedValues.length); + for (int i = 0; i < values.length; i++) + if (wrappedValues[i] != null) + values[i] = ((ObjectWritable) wrappedValues[i]).get(); + + return values; + } finally { + CLIENTS.stopClient(client); + } + } + + /** + * Construct a server for a protocol implementation instance listening on a + * port and address. + * + * @param instance + * @param bindAddress + * @param port + * @param conf + * @return server + * @throws IOException + */ + public static NioServer getServer(final Object instance, + final String bindAddress, final int port, Configuration conf) + throws IOException { + return getServer(instance, bindAddress, port, 1, false, conf); + } + + /** + * Construct a server for a protocol implementation instance listening on a + * port and address. + * + * @param instance + * @param bindAddress + * @param port + * @param numHandlers + * @param verbose + * @param conf + * @return server + * @throws IOException + */ + public static NioServer getServer(final Object instance, + final String bindAddress, final int port, final int numHandlers, + final boolean verbose, Configuration conf) throws IOException { + return getServer(instance, bindAddress, port, numHandlers, verbose, conf, + null); + } + + /** + * Construct a server for a protocol implementation instance listening on a + * port and address, with a secret manager. + * + * @param instance + * @param bindAddress + * @param port + * @param numHandlers + * @param verbose + * @param conf + * @param secretManager + * @return server + * @throws IOException + */ + public static NioServer getServer(final Object instance, + final String bindAddress, final int port, final int numHandlers, + final boolean verbose, Configuration conf, + SecretManager secretManager) + throws IOException { + return new NioServer(instance, conf, bindAddress, port, numHandlers, + verbose, secretManager); + } + + /** An RPC Server. */ + public static class NioServer extends org.apache.hama.ipc.AsyncServer { + private Object instance; + private boolean verbose; + + /** + * Construct an RPC server. + * + * @param instance the instance whose methods will be called + * @param conf the configuration to use + * @param bindAddress the address to bind on to listen for connection + * @param port the port to listen for connections on + * @throws IOException + */ + public NioServer(Object instance, Configuration conf, String bindAddress, + int port) throws IOException { + this(instance, conf, bindAddress, port, 1, false, null); + } + + private static String classNameBase(String className) { + String[] names = className.split("\\.", -1); + if (names == null || names.length == 0) { + return className; + } + return names[names.length - 1]; + } + + /** + * Construct an RPC server. + * + * @param instance the instance whose methods will be called + * @param conf the configuration to use + * @param bindAddress the address to bind on to listen for connection + * @param port the port to listen for connections on + * @param numHandlers the number of method handler threads to run + * @param verbose whether each call should be logged + * @throws IOException + */ + public NioServer(Object instance, Configuration conf, String bindAddress, + int port, int numHandlers, boolean verbose, + SecretManager secretManager) + throws IOException { + super(bindAddress, port, Invocation.class, numHandlers, conf, + classNameBase(instance.getClass().getName()), secretManager); + this.instance = instance; + this.verbose = verbose; + } + + public Writable call(Class protocol, Writable param, long receivedTime) + throws IOException { + try { + Invocation call = (Invocation) param; + if (verbose) + log("Call: " + call); + + Method method = protocol.getMethod(call.getMethodName(), + call.getParameterClasses()); + method.setAccessible(true); + + long startTime = System.currentTimeMillis(); + Object value = method.invoke(instance, call.getParameters()); + int processingTime = (int) (System.currentTimeMillis() - startTime); + int qTime = (int) (startTime - receivedTime); + if (LOG.isDebugEnabled()) { + LOG.debug("Served: " + call.getMethodName() + " queueTime= " + qTime + + " procesingTime= " + processingTime); + } + if (verbose) + log("Return: " + value); + + return new ObjectWritable(method.getReturnType(), value); + + } catch (InvocationTargetException e) { + Throwable target = e.getTargetException(); + if (target instanceof IOException) { + throw (IOException) target; + } else { + IOException ioe = new IOException(target.toString()); + ioe.setStackTrace(target.getStackTrace()); + throw ioe; + } + } catch (Throwable e) { + if (!(e instanceof IOException)) { + LOG.error("Unexpected throwable object ", e); + } + IOException ioe = new IOException(e.toString()); + ioe.setStackTrace(e.getStackTrace()); + throw ioe; + } + } + } + + private static void log(String value) { + if (value != null && value.length() > 55) + value = value.substring(0, 55) + "..."; + LOG.info(value); + } +} Index: core/src/main/java/org/apache/hama/bsp/message/HamaAsyncMessageManagerImpl.java =================================================================== --- core/src/main/java/org/apache/hama/bsp/message/HamaAsyncMessageManagerImpl.java (리비전 0) +++ core/src/main/java/org/apache/hama/bsp/message/HamaAsyncMessageManagerImpl.java (리비전 0) @@ -0,0 +1,172 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.bsp.message; + +import java.io.IOException; +import java.net.BindException; +import java.net.InetSocketAddress; +import java.util.Map; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.Writable; +import org.apache.hama.HamaConfiguration; +import org.apache.hama.bsp.BSPMessageBundle; +import org.apache.hama.bsp.BSPPeer; +import org.apache.hama.bsp.BSPPeerImpl; +import org.apache.hama.bsp.TaskAttemptID; +import org.apache.hama.ipc.HamaRPCProtocolVersion; +import org.apache.hama.ipc.AsyncRPC; +import org.apache.hama.ipc.AsyncServer; +import org.apache.hama.util.LRUCache; + +/** + * Implementation of the {@link HamaMessageManager}. + * + */ +public final class HamaAsyncMessageManagerImpl extends + AbstractMessageManager implements HamaMessageManager { + + private static final Log LOG = LogFactory + .getLog(HamaAsyncMessageManagerImpl.class); + + private AsyncServer server; + + private LRUCache> peersLRUCache = null; + + @SuppressWarnings("serial") + @Override + public final void init(TaskAttemptID attemptId, BSPPeer peer, + HamaConfiguration conf, InetSocketAddress peerAddress) { + super.init(attemptId, peer, conf, peerAddress); + startRPCServer(conf, peerAddress); + peersLRUCache = new LRUCache>( + maxCachedConnections) { + @Override + protected final boolean removeEldestEntry( + Map.Entry> eldest) { + if (size() > this.capacity) { + HamaMessageManager proxy = eldest.getValue(); + AsyncRPC.stopProxy(proxy); + return true; + } + return false; + } + }; + } + + private final void startRPCServer(Configuration conf, + InetSocketAddress peerAddress) { + try { + startServer(peerAddress.getHostName(), peerAddress.getPort()); + } catch (IOException ioe) { + LOG.error("Fail to start RPC server!", ioe); + throw new RuntimeException("RPC Server could not be launched!"); + } + } + + private void startServer(String hostName, int port) throws IOException { + int retry = 0; + try { + this.server = AsyncRPC.getServer(this, hostName, port, + conf.getInt("hama.default.messenger.handler.threads.num", 5), false, + conf); + + server.start(); + LOG.info("BSPPeer address:" + server.getAddress().getHostName() + + " port:" + server.getAddress().getPort()); + } catch (BindException e) { + LOG.warn("Address already in use. Retrying " + hostName + ":" + port + 1); + startServer(hostName, port + 1); + retry++; + + if (retry > 5) { + throw new RuntimeException("RPC Server could not be launched!"); + } + } + } + + @Override + public final void close() { + super.close(); + if (server != null) { + server.stop(); + } + } + + @Override + public final void transfer(InetSocketAddress addr, BSPMessageBundle bundle) + throws IOException { + HamaMessageManager bspPeerConnection = this.getBSPPeerConnection(addr); + if (bspPeerConnection == null) { + throw new IllegalArgumentException("Can not find " + addr.toString() + + " to transfer messages to!"); + } else { + peer.incrementCounter(BSPPeerImpl.PeerCounter.MESSAGE_BYTES_TRANSFERED, bundle.getLength()); + bspPeerConnection.put(bundle); + } + } + + /** + * @param addr socket address to which BSP Peer Connection will be + * established + * @return BSP Peer Connection, tried to return cached connection, else + * returns a new connection and caches it + * @throws IOException + */ + @SuppressWarnings("unchecked") + protected final HamaMessageManager getBSPPeerConnection( + InetSocketAddress addr) throws IOException { + HamaMessageManager bspPeerConnection; + if (!peersLRUCache.containsKey(addr)) { + bspPeerConnection = (HamaMessageManager) AsyncRPC.getProxy( + HamaMessageManager.class, HamaRPCProtocolVersion.versionID, addr, + this.conf); + peersLRUCache.put(addr, bspPeerConnection); + } else { + bspPeerConnection = peersLRUCache.get(addr); + } + return bspPeerConnection; + } + + @Override + public final void put(M msg) throws IOException { + loopBackMessage(msg); + } + + @Override + public final void put(BSPMessageBundle bundle) throws IOException { + loopBackBundle(bundle); + } + + @Override + public final long getProtocolVersion(String arg0, long arg1) + throws IOException { + return versionID; + } + + @Override + public InetSocketAddress getListenerAddress() { + if (this.server != null) { + return this.server.getAddress(); + } + return null; + } + +} Index: core/pom.xml =================================================================== --- core/pom.xml (리비전 1610889) +++ core/pom.xml (작업 사본) @@ -135,6 +135,11 @@ org.apache.zookeeper zookeeper + + io.netty + netty-all + 4.0.20.Final +