Index: core/src/main/java/org/apache/hama/util/BSPNetUtils.java =================================================================== --- core/src/main/java/org/apache/hama/util/BSPNetUtils.java (revision 1535330) +++ core/src/main/java/org/apache/hama/util/BSPNetUtils.java (working copy) @@ -44,6 +44,9 @@ import org.apache.hadoop.net.NetUtils; import org.apache.hadoop.security.SecurityUtil; import org.apache.hama.Constants; +import org.apache.hama.commons.io.SocketIOWithTimeout; +import org.apache.hama.commons.io.SocketInputStream; +import org.apache.hama.commons.io.SocketOutputStream; import org.apache.hama.ipc.Server; /** Index: core/src/main/java/org/apache/hama/util/SocketIOWithTimeout.java =================================================================== --- core/src/main/java/org/apache/hama/util/SocketIOWithTimeout.java (revision 1535330) +++ core/src/main/java/org/apache/hama/util/SocketIOWithTimeout.java (working copy) @@ -1,453 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.hama.util; - -import java.io.IOException; -import java.io.InterruptedIOException; -import java.net.SocketAddress; -import java.net.SocketTimeoutException; -import java.nio.ByteBuffer; -import java.nio.channels.SelectableChannel; -import java.nio.channels.SelectionKey; -import java.nio.channels.Selector; -import java.nio.channels.SocketChannel; -import java.nio.channels.spi.SelectorProvider; -import java.util.Iterator; -import java.util.LinkedList; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.apache.hadoop.util.StringUtils; - -/** - * This supports input and output streams for a socket channels. These streams - * can have a timeout. - */ -abstract class SocketIOWithTimeout { - // This is intentionally package private. - - static final Log LOG = LogFactory.getLog(SocketIOWithTimeout.class); - - private SelectableChannel channel; - private long timeout; - private boolean closed = false; - - private static SelectorPool selector = new SelectorPool(); - - /* - * A timeout value of 0 implies wait for ever. We should have a value of - * timeout that implies zero wait.. i.e. read or write returns immediately. - * This will set channel to non-blocking. - */ - SocketIOWithTimeout(SelectableChannel channel, long timeout) - throws IOException { - checkChannelValidity(channel); - - this.channel = channel; - this.timeout = timeout; - // Set non-blocking - channel.configureBlocking(false); - } - - void close() { - closed = true; - } - - boolean isOpen() { - return !closed && channel.isOpen(); - } - - SelectableChannel getChannel() { - return channel; - } - - /** - * Utility function to check if channel is ok. Mainly to throw IOException - * instead of runtime exception in case of mismatch. This mismatch can occur - * for many runtime reasons. 
- */ - static void checkChannelValidity(Object channel) throws IOException { - if (channel == null) { - /* - * Most common reason is that original socket does not have a channel. So - * making this an IOException rather than a RuntimeException. - */ - throw new IOException("Channel is null. Check " - + "how the channel or socket is created."); - } - - if (!(channel instanceof SelectableChannel)) { - throw new IOException("Channel should be a SelectableChannel"); - } - } - - /** - * Performs actual IO operations. This is not expected to block. - * - * @param buf - * @return number of bytes (or some equivalent). 0 implies underlying channel - * is drained completely. We will wait if more IO is required. - * @throws IOException - */ - abstract int performIO(ByteBuffer buf) throws IOException; - - /** - * Performs one IO and returns number of bytes read or written. It waits up to - * the specified timeout. If the channel is not read before the timeout, - * SocketTimeoutException is thrown. - * - * @param buf buffer for IO - * @param ops Selection Ops used for waiting. Suggested values: - * SelectionKey.OP_READ while reading and SelectionKey.OP_WRITE while - * writing. - * - * @return number of bytes read or written. negative implies end of stream. - * @throws IOException - */ - int doIO(ByteBuffer buf, int ops) throws IOException { - - /* - * For now only one thread is allowed. If user want to read or write from - * multiple threads, multiple streams could be created. In that case - * multiple threads work as well as underlying channel supports it. - */ - if (!buf.hasRemaining()) { - throw new IllegalArgumentException("Buffer has no data left."); - // or should we just return 0? - } - - while (buf.hasRemaining()) { - if (closed) { - return -1; - } - - try { - int n = performIO(buf); - if (n != 0) { - // successful io or an error. 
- return n; - } - } catch (IOException e) { - if (!channel.isOpen()) { - closed = true; - } - throw e; - } - - // now wait for socket to be ready. - int count = 0; - try { - count = selector.select(channel, ops, timeout); - } catch (IOException e) { // unexpected IOException. - closed = true; - throw e; - } - - if (count == 0) { - throw new SocketTimeoutException(timeoutExceptionString(channel, - timeout, ops)); - } - // otherwise the socket should be ready for io. - } - - return 0; // does not reach here. - } - - /** - * The contract is similar to {@link SocketChannel#connect(SocketAddress)} - * with a timeout. - * - * @see SocketChannel#connect(SocketAddress) - * - * @param channel - this should be a {@link SelectableChannel} - * @param endpoint - * @throws IOException - */ - static void connect(SocketChannel channel, SocketAddress endpoint, int timeout) - throws IOException { - - boolean blockingOn = channel.isBlocking(); - if (blockingOn) { - channel.configureBlocking(false); - } - - try { - if (channel.connect(endpoint)) { - return; - } - - long timeoutLeft = timeout; - long endTime = (timeout > 0) ? (System.currentTimeMillis() + timeout) : 0; - - while (true) { - // we might have to call finishConnect() more than once - // for some channels (with user level protocols) - - int ret = selector.select((SelectableChannel) channel, - SelectionKey.OP_CONNECT, timeoutLeft); - - if (ret > 0 && channel.finishConnect()) { - return; - } - - if (ret == 0 - || (timeout > 0 && (timeoutLeft = (endTime - System - .currentTimeMillis())) <= 0)) { - throw new SocketTimeoutException(timeoutExceptionString(channel, - timeout, SelectionKey.OP_CONNECT)); - } - } - } catch (IOException e) { - // javadoc for SocketChannel.connect() says channel should be closed. 
- try { - channel.close(); - } catch (IOException ignored) { - } - throw e; - } finally { - if (blockingOn && channel.isOpen()) { - channel.configureBlocking(true); - } - } - } - - /** - * This is similar to {@link #doIO(ByteBuffer, int)} except that it does not - * perform any I/O. It just waits for the channel to be ready for I/O as - * specified in ops. - * - * @param ops Selection Ops used for waiting - * - * @throws SocketTimeoutException if select on the channel times out. - * @throws IOException if any other I/O error occurs. - */ - void waitForIO(int ops) throws IOException { - - if (selector.select(channel, ops, timeout) == 0) { - throw new SocketTimeoutException(timeoutExceptionString(channel, timeout, - ops)); - } - } - - private static String timeoutExceptionString(SelectableChannel channel, - long timeout, int ops) { - - String waitingFor; - switch (ops) { - - case SelectionKey.OP_READ: - waitingFor = "read"; - break; - - case SelectionKey.OP_WRITE: - waitingFor = "write"; - break; - - case SelectionKey.OP_CONNECT: - waitingFor = "connect"; - break; - - default: - waitingFor = "" + ops; - } - - return timeout + " millis timeout while " - + "waiting for channel to be ready for " + waitingFor + ". ch : " - + channel; - } - - /** - * This maintains a pool of selectors. These selectors are closed once they - * are idle (unused) for a few seconds. - */ - private static class SelectorPool { - - private static class SelectorInfo { - Selector selector; - long lastActivityTime; - LinkedList queue; - - void close() { - if (selector != null) { - try { - selector.close(); - } catch (IOException e) { - LOG.warn("Unexpected exception while closing selector : " - + StringUtils.stringifyException(e)); - } - } - } - } - - private static class ProviderInfo { - SelectorProvider provider; - LinkedList queue; // lifo - ProviderInfo next; - } - - private static final long IDLE_TIMEOUT = 10 * 1000; // 10 seconds. 
- - private ProviderInfo providerList = null; - - /** - * Waits on the channel with the given timeout using one of the cached - * selectors. It also removes any cached selectors that are idle for a few - * seconds. - * - * @param channel - * @param ops - * @param timeout - * @return - * @throws IOException - */ - int select(SelectableChannel channel, int ops, long timeout) - throws IOException { - - SelectorInfo info = get(channel); - - SelectionKey key = null; - int ret = 0; - - try { - while (true) { - long start = (timeout == 0) ? 0 : System.currentTimeMillis(); - - key = channel.register(info.selector, ops); - ret = info.selector.select(timeout); - - if (ret != 0) { - return ret; - } - - /* - * Sometimes select() returns 0 much before timeout for unknown - * reasons. So select again if required. - */ - if (timeout > 0) { - timeout -= System.currentTimeMillis() - start; - if (timeout <= 0) { - return 0; - } - } - - if (Thread.currentThread().isInterrupted()) { - throw new InterruptedIOException("Interruped while waiting for " - + "IO on channel " + channel + ". " + timeout - + " millis timeout left."); - } - } - } finally { - if (key != null) { - key.cancel(); - } - - // clear the canceled key. - try { - info.selector.selectNow(); - } catch (IOException e) { - LOG.info("Unexpected Exception while clearing selector : " - + StringUtils.stringifyException(e)); - // don't put the selector back. - info.close(); - return ret; - } - - release(info); - } - } - - /** - * Takes one selector from end of LRU list of free selectors. If there are - * no selectors awailable, it creates a new selector. Also invokes - * trimIdleSelectors(). - * - * @param channel - * @return - * @throws IOException - */ - private synchronized SelectorInfo get(SelectableChannel channel) - throws IOException { - SelectorInfo selInfo = null; - - SelectorProvider provider = channel.provider(); - - // pick the list : rarely there is more than one provider in use. 
- ProviderInfo pList = providerList; - while (pList != null && pList.provider != provider) { - pList = pList.next; - } - if (pList == null) { - // LOG.info("Creating new ProviderInfo : " + provider.toString()); - pList = new ProviderInfo(); - pList.provider = provider; - pList.queue = new LinkedList(); - pList.next = providerList; - providerList = pList; - } - - LinkedList queue = pList.queue; - - if (queue.isEmpty()) { - Selector selector = provider.openSelector(); - selInfo = new SelectorInfo(); - selInfo.selector = selector; - selInfo.queue = queue; - } else { - selInfo = queue.removeLast(); - } - - trimIdleSelectors(System.currentTimeMillis()); - return selInfo; - } - - /** - * puts selector back at the end of LRU list of free selectos. Also invokes - * trimIdleSelectors(). - * - * @param info - */ - private synchronized void release(SelectorInfo info) { - long now = System.currentTimeMillis(); - trimIdleSelectors(now); - info.lastActivityTime = now; - info.queue.addLast(info); - } - - /** - * Closes selectors that are idle for IDLE_TIMEOUT (10 sec). It does not - * traverse the whole list, just over the one that have crossed the timeout. 
- */ - private void trimIdleSelectors(long now) { - long cutoff = now - IDLE_TIMEOUT; - - for (ProviderInfo pList = providerList; pList != null; pList = pList.next) { - if (pList.queue.isEmpty()) { - continue; - } - for (Iterator it = pList.queue.iterator(); it.hasNext();) { - SelectorInfo info = it.next(); - if (info.lastActivityTime > cutoff) { - break; - } - it.remove(); - info.close(); - } - } - } - } -} Index: core/src/main/java/org/apache/hama/util/SocketInputStream.java =================================================================== --- core/src/main/java/org/apache/hama/util/SocketInputStream.java (revision 1535330) +++ core/src/main/java/org/apache/hama/util/SocketInputStream.java (working copy) @@ -1,169 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.hama.util; - -import java.io.IOException; -import java.io.InputStream; -import java.net.Socket; -import java.net.SocketTimeoutException; -import java.nio.ByteBuffer; -import java.nio.channels.FileChannel; -import java.nio.channels.ReadableByteChannel; -import java.nio.channels.SelectableChannel; -import java.nio.channels.SelectionKey; - -/** - * This implements an input stream that can have a timeout while reading. 
This - * sets non-blocking flag on the socket channel. So after create this object, - * read() on {@link Socket#getInputStream()} and write() on - * {@link Socket#getOutputStream()} for the associated socket will throw - * IllegalBlockingModeException. Please use {@link SocketOutputStream} for - * writing. - */ -public class SocketInputStream extends InputStream implements - ReadableByteChannel { - - private Reader reader; - - private static class Reader extends SocketIOWithTimeout { - ReadableByteChannel channel; - - Reader(ReadableByteChannel channel, long timeout) throws IOException { - super((SelectableChannel) channel, timeout); - this.channel = channel; - } - - int performIO(ByteBuffer buf) throws IOException { - return channel.read(buf); - } - } - - /** - * Create a new input stream with the given timeout. If the timeout is zero, - * it will be treated as infinite timeout. The socket's channel will be - * configured to be non-blocking. - * - * @param channel Channel for reading, should also be a - * {@link SelectableChannel}. The channel will be configured to be - * non-blocking. - * @param timeout timeout in milliseconds. must not be negative. - * @throws IOException - */ - public SocketInputStream(ReadableByteChannel channel, long timeout) - throws IOException { - SocketIOWithTimeout.checkChannelValidity(channel); - reader = new Reader(channel, timeout); - } - - /** - * Same as SocketInputStream(socket.getChannel(), timeout):
- *
- * - * Create a new input stream with the given timeout. If the timeout is zero, - * it will be treated as infinite timeout. The socket's channel will be - * configured to be non-blocking. - * - * @see SocketInputStream#SocketInputStream(ReadableByteChannel, long) - * - * @param socket should have a channel associated with it. - * @param timeout timeout timeout in milliseconds. must not be negative. - * @throws IOException - */ - public SocketInputStream(Socket socket, long timeout) throws IOException { - this(socket.getChannel(), timeout); - } - - /** - * Same as SocketInputStream(socket.getChannel(), socket.getSoTimeout()) :
- *
- * - * Create a new input stream with the given timeout. If the timeout is zero, - * it will be treated as infinite timeout. The socket's channel will be - * configured to be non-blocking. - * - * @see SocketInputStream#SocketInputStream(ReadableByteChannel, long) - * - * @param socket should have a channel associated with it. - * @throws IOException - */ - public SocketInputStream(Socket socket) throws IOException { - this(socket.getChannel(), socket.getSoTimeout()); - } - - @Override - public int read() throws IOException { - /* - * Allocation can be removed if required. probably no need to optimize or - * encourage single byte read. - */ - byte[] buf = new byte[1]; - int ret = read(buf, 0, 1); - if (ret > 0) { - return (byte) buf[0]; - } - if (ret != -1) { - // unexpected - throw new IOException("Could not read from stream"); - } - return ret; - } - - public int read(byte[] b, int off, int len) throws IOException { - return read(ByteBuffer.wrap(b, off, len)); - } - - public synchronized void close() throws IOException { - /* - * close the channel since Socket.getInputStream().close() closes the - * socket. - */ - reader.channel.close(); - reader.close(); - } - - /** - * Returns underlying channel used by inputstream. This is useful in certain - * cases like channel for - * {@link FileChannel#transferFrom(ReadableByteChannel, long, long)}. - */ - public ReadableByteChannel getChannel() { - return reader.channel; - } - - // ReadableByteChannel interface - - public boolean isOpen() { - return reader.isOpen(); - } - - public int read(ByteBuffer dst) throws IOException { - return reader.doIO(dst, SelectionKey.OP_READ); - } - - /** - * waits for the underlying channel to be ready for reading. The timeout - * specified for this stream applies to this wait. - * - * @throws SocketTimeoutException if select on the channel times out. - * @throws IOException if any other I/O error occurs. 
- */ - public void waitForReadable() throws IOException { - reader.waitForIO(SelectionKey.OP_READ); - } -} Index: core/src/main/java/org/apache/hama/util/SocketOutputStream.java =================================================================== --- core/src/main/java/org/apache/hama/util/SocketOutputStream.java (revision 1535330) +++ core/src/main/java/org/apache/hama/util/SocketOutputStream.java (working copy) @@ -1,214 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.hama.util; - -import java.io.EOFException; -import java.io.IOException; -import java.io.OutputStream; -import java.net.Socket; -import java.net.SocketTimeoutException; -import java.nio.ByteBuffer; -import java.nio.channels.FileChannel; -import java.nio.channels.SelectableChannel; -import java.nio.channels.SelectionKey; -import java.nio.channels.WritableByteChannel; - -/** - * This implements an output stream that can have a timeout while writing. This - * sets non-blocking flag on the socket channel. So after creating this object , - * read() on {@link Socket#getInputStream()} and write() on - * {@link Socket#getOutputStream()} on the associated socket will throw - * llegalBlockingModeException. 
Please use {@link SocketInputStream} for - * reading. - */ -public class SocketOutputStream extends OutputStream implements - WritableByteChannel { - - private Writer writer; - - private static class Writer extends SocketIOWithTimeout { - WritableByteChannel channel; - - Writer(WritableByteChannel channel, long timeout) throws IOException { - super((SelectableChannel) channel, timeout); - this.channel = channel; - } - - int performIO(ByteBuffer buf) throws IOException { - return channel.write(buf); - } - } - - /** - * Create a new ouput stream with the given timeout. If the timeout is zero, - * it will be treated as infinite timeout. The socket's channel will be - * configured to be non-blocking. - * - * @param channel Channel for writing, should also be a - * {@link SelectableChannel}. The channel will be configured to be - * non-blocking. - * @param timeout timeout in milliseconds. must not be negative. - * @throws IOException - */ - public SocketOutputStream(WritableByteChannel channel, long timeout) - throws IOException { - SocketIOWithTimeout.checkChannelValidity(channel); - writer = new Writer(channel, timeout); - } - - /** - * Same as SocketOutputStream(socket.getChannel(), timeout):
- *
- * - * Create a new ouput stream with the given timeout. If the timeout is zero, - * it will be treated as infinite timeout. The socket's channel will be - * configured to be non-blocking. - * - * @see SocketOutputStream#SocketOutputStream(WritableByteChannel, long) - * - * @param socket should have a channel associated with it. - * @param timeout timeout timeout in milliseconds. must not be negative. - * @throws IOException - */ - public SocketOutputStream(Socket socket, long timeout) throws IOException { - this(socket.getChannel(), timeout); - } - - public void write(int b) throws IOException { - /* - * If we need to, we can optimize this allocation. probably no need to - * optimize or encourage single byte writes. - */ - byte[] buf = new byte[1]; - buf[0] = (byte) b; - write(buf, 0, 1); - } - - public void write(byte[] b, int off, int len) throws IOException { - ByteBuffer buf = ByteBuffer.wrap(b, off, len); - while (buf.hasRemaining()) { - try { - if (write(buf) < 0) { - throw new IOException("The stream is closed"); - } - } catch (IOException e) { - /* - * Unlike read, write can not inform user of partial writes. So will - * close this if there was a partial write. - */ - if (buf.capacity() > buf.remaining()) { - writer.close(); - } - throw e; - } - } - } - - public synchronized void close() throws IOException { - /* - * close the channel since Socket.getOuputStream().close() closes the - * socket. - */ - writer.channel.close(); - writer.close(); - } - - /** - * Returns underlying channel used by this stream. 
This is useful in certain - * cases like channel for - * {@link FileChannel#transferTo(long, long, WritableByteChannel)} - */ - public WritableByteChannel getChannel() { - return writer.channel; - } - - // WritableByteChannle interface - - public boolean isOpen() { - return writer.isOpen(); - } - - public int write(ByteBuffer src) throws IOException { - return writer.doIO(src, SelectionKey.OP_WRITE); - } - - /** - * waits for the underlying channel to be ready for writing. The timeout - * specified for this stream applies to this wait. - * - * @throws SocketTimeoutException if select on the channel times out. - * @throws IOException if any other I/O error occurs. - */ - public void waitForWritable() throws IOException { - writer.waitForIO(SelectionKey.OP_WRITE); - } - - /** - * Transfers data from FileChannel using - * {@link FileChannel#transferTo(long, long, WritableByteChannel)}. - * - * Similar to readFully(), this waits till requested amount of data is - * transfered. - * - * @param fileCh FileChannel to transfer data from. - * @param position position within the channel where the transfer begins - * @param count number of bytes to transfer. - * - * @throws EOFException If end of input file is reached before requested - * number of bytes are transfered. - * - * @throws SocketTimeoutException If this channel blocks transfer longer than - * timeout for this stream. - * - * @throws IOException Includes any exception thrown by - * {@link FileChannel#transferTo(long, long, WritableByteChannel)}. - */ - public void transferToFully(FileChannel fileCh, long position, int count) - throws IOException { - - while (count > 0) { - /* - * Ideally we should wait after transferTo returns 0. But because of a bug - * in JRE on Linux (http://bugs.sun.com/view_bug.do?bug_id=5103988), which - * throws an exception instead of returning 0, we wait for the channel to - * be writable before writing to it. 
If you ever see IOException with - * message "Resource temporarily unavailable" thrown here, please let us - * know. Once we move to JAVA SE 7, wait should be moved to correct place. - */ - waitForWritable(); - int nTransfered = (int) fileCh.transferTo(position, count, getChannel()); - - if (nTransfered == 0) { - // check if end of file is reached. - if (position >= fileCh.size()) { - throw new EOFException("EOF Reached. file size is " + fileCh.size() - + " and " + count + " more bytes left to be " + "transfered."); - } - // otherwise assume the socket is full. - // waitForWritable(); // see comment above. - } else if (nTransfered < 0) { - throw new IOException("Unexpected return of " + nTransfered - + " from transferTo()"); - } else { - position += nTransfered; - count -= nTransfered; - } - } - } -} Index: core/src/main/java/org/apache/hama/util/StringArrayWritable.java =================================================================== --- core/src/main/java/org/apache/hama/util/StringArrayWritable.java (revision 1535330) +++ core/src/main/java/org/apache/hama/util/StringArrayWritable.java (working copy) @@ -1,65 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.hama.util; - -import java.io.DataInput; -import java.io.DataOutput; -import java.io.IOException; - -import org.apache.hadoop.io.Writable; - -/** - * Custom writable for string arrays, because ArrayWritable has no default - * constructor and is broken. - * - */ -public class StringArrayWritable implements Writable { - - private String[] array; - - public StringArrayWritable() { - super(); - } - - public StringArrayWritable(String[] array) { - super(); - this.array = array; - } - - // no defensive copy needed because this always comes from an rpc call. - public String[] get() { - return array; - } - - @Override - public void write(DataOutput out) throws IOException { - out.writeInt(array.length); - for (String s : array) { - out.writeUTF(s); - } - } - - @Override - public void readFields(DataInput in) throws IOException { - array = new String[in.readInt()]; - for (int i = 0; i < array.length; i++) { - array[i] = in.readUTF(); - } - } - -} Index: core/src/main/java/org/apache/hama/util/KeyValuePair.java =================================================================== --- core/src/main/java/org/apache/hama/util/KeyValuePair.java (revision 1535330) +++ core/src/main/java/org/apache/hama/util/KeyValuePair.java (working copy) @@ -1,59 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.util; - -/** - * Mutable class for key values. - */ -public class KeyValuePair { - - private K key; - private V value; - - public KeyValuePair() { - - } - - public KeyValuePair(K key, V value) { - super(); - this.key = key; - this.value = value; - } - - public K getKey() { - return key; - } - - public V getValue() { - return value; - } - - public void setKey(K key) { - this.key = key; - } - - public void setValue(V value) { - this.value = value; - } - - public void clear() { - this.key = null; - this.value = null; - } - -} Index: core/src/main/java/org/apache/hama/pipes/protocol/StreamingProtocol.java =================================================================== --- core/src/main/java/org/apache/hama/pipes/protocol/StreamingProtocol.java (revision 1535330) +++ core/src/main/java/org/apache/hama/pipes/protocol/StreamingProtocol.java (working copy) @@ -35,7 +35,7 @@ import org.apache.hadoop.io.Writable; import org.apache.hama.bsp.BSPPeer; import org.apache.hama.bsp.sync.SyncException; -import org.apache.hama.util.KeyValuePair; +import org.apache.hama.commons.util.KeyValuePair; /** * Streaming protocol that inherits from the binary protocol. 
Basically it Index: core/src/main/java/org/apache/hama/pipes/protocol/UplinkReader.java =================================================================== --- core/src/main/java/org/apache/hama/pipes/protocol/UplinkReader.java (revision 1535330) +++ core/src/main/java/org/apache/hama/pipes/protocol/UplinkReader.java (working copy) @@ -44,7 +44,7 @@ import org.apache.hama.Constants; import org.apache.hama.bsp.BSPPeer; import org.apache.hama.bsp.sync.SyncException; -import org.apache.hama.util.KeyValuePair; +import org.apache.hama.commons.util.KeyValuePair; public class UplinkReader extends Thread { Index: core/src/main/java/org/apache/hama/bsp/BSPPeer.java =================================================================== --- core/src/main/java/org/apache/hama/bsp/BSPPeer.java (revision 1535330) +++ core/src/main/java/org/apache/hama/bsp/BSPPeer.java (working copy) @@ -24,7 +24,7 @@ import org.apache.hama.Constants; import org.apache.hama.bsp.Counters.Counter; import org.apache.hama.bsp.sync.SyncException; -import org.apache.hama.util.KeyValuePair; +import org.apache.hama.commons.util.KeyValuePair; /** * BSP communication interface. 
Reads key-value inputs, with K1 typed keys and Index: core/src/main/java/org/apache/hama/bsp/PartitioningRunner.java =================================================================== --- core/src/main/java/org/apache/hama/bsp/PartitioningRunner.java (revision 1535330) +++ core/src/main/java/org/apache/hama/bsp/PartitioningRunner.java (working copy) @@ -36,7 +36,7 @@ import org.apache.hadoop.util.ReflectionUtils; import org.apache.hama.Constants; import org.apache.hama.bsp.sync.SyncException; -import org.apache.hama.util.KeyValuePair; +import org.apache.hama.commons.util.KeyValuePair; public class PartitioningRunner extends BSP { Index: core/src/main/java/org/apache/hama/bsp/TextArrayWritable.java =================================================================== --- core/src/main/java/org/apache/hama/bsp/TextArrayWritable.java (revision 1535330) +++ core/src/main/java/org/apache/hama/bsp/TextArrayWritable.java (working copy) @@ -1,29 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.hama.bsp; - -import org.apache.hadoop.io.ArrayWritable; -import org.apache.hadoop.io.Text; - -public class TextArrayWritable extends ArrayWritable { - - public TextArrayWritable() { - super(Text.class); - } - -} Index: core/src/main/java/org/apache/hama/bsp/BSPPeerImpl.java =================================================================== --- core/src/main/java/org/apache/hama/bsp/BSPPeerImpl.java (revision 1535330) +++ core/src/main/java/org/apache/hama/bsp/BSPPeerImpl.java (working copy) @@ -46,10 +46,10 @@ import org.apache.hama.bsp.sync.PeerSyncClient; import org.apache.hama.bsp.sync.SyncException; import org.apache.hama.bsp.sync.SyncServiceFactory; +import org.apache.hama.commons.util.KeyValuePair; import org.apache.hama.ipc.BSPPeerProtocol; import org.apache.hama.pipes.util.DistributedCacheUtil; import org.apache.hama.util.DistCacheUtils; -import org.apache.hama.util.KeyValuePair; /** * This class represents a BSP peer. Index: core/src/test/java/org/apache/hama/bsp/TestKeyValueTextInputFormat.java =================================================================== --- core/src/test/java/org/apache/hama/bsp/TestKeyValueTextInputFormat.java (revision 1535330) +++ core/src/test/java/org/apache/hama/bsp/TestKeyValueTextInputFormat.java (working copy) @@ -22,6 +22,8 @@ import java.util.HashMap; import java.util.Map; +import junit.framework.TestCase; + import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.FileSystem; @@ -33,11 +35,9 @@ import org.apache.hadoop.io.Writable; import org.apache.hama.HamaConfiguration; import org.apache.hama.bsp.sync.SyncException; -import org.apache.hama.util.KeyValuePair; +import org.apache.hama.commons.util.KeyValuePair; import org.junit.Test; -import junit.framework.TestCase; - public class TestKeyValueTextInputFormat extends TestCase { public static class KeyValueHashPartitionedBSP Index: 
core/src/test/java/org/apache/hama/bsp/TestCheckpoint.java =================================================================== --- core/src/test/java/org/apache/hama/bsp/TestCheckpoint.java (revision 1535330) +++ core/src/test/java/org/apache/hama/bsp/TestCheckpoint.java (working copy) @@ -58,8 +58,8 @@ import org.apache.hama.bsp.sync.SyncEventListener; import org.apache.hama.bsp.sync.SyncException; import org.apache.hama.bsp.sync.SyncServiceFactory; +import org.apache.hama.commons.util.KeyValuePair; import org.apache.hama.util.BSPNetUtils; -import org.apache.hama.util.KeyValuePair; public class TestCheckpoint extends TestCase { Index: core/src/test/java/org/apache/hama/bsp/TestPartitioning.java =================================================================== --- core/src/test/java/org/apache/hama/bsp/TestPartitioning.java (revision 1535330) +++ core/src/test/java/org/apache/hama/bsp/TestPartitioning.java (working copy) @@ -32,7 +32,7 @@ import org.apache.hama.HamaConfiguration; import org.apache.hama.bsp.message.queue.DiskQueue; import org.apache.hama.bsp.sync.SyncException; -import org.apache.hama.util.KeyValuePair; +import org.apache.hama.commons.util.KeyValuePair; public class TestPartitioning extends HamaCluster { Index: core/pom.xml =================================================================== --- core/pom.xml (revision 1535330) +++ core/pom.xml (working copy) @@ -32,6 +32,11 @@ + org.apache.hama + hama-commons + ${project.version} + + org.xerial.snappy snappy-java 1.0.5 Index: pom.xml =================================================================== --- pom.xml (revision 1535330) +++ pom.xml (working copy) @@ -282,6 +282,7 @@ c++ + commons core graph examples Index: graph/src/test/java/org/apache/hama/graph/example/PageRank.java =================================================================== --- graph/src/test/java/org/apache/hama/graph/example/PageRank.java (revision 1535330) +++ graph/src/test/java/org/apache/hama/graph/example/PageRank.java 
(working copy) @@ -28,8 +28,8 @@ import org.apache.hama.HamaConfiguration; import org.apache.hama.bsp.HashPartitioner; import org.apache.hama.bsp.SequenceFileInputFormat; -import org.apache.hama.bsp.TextArrayWritable; import org.apache.hama.bsp.TextOutputFormat; +import org.apache.hama.commons.io.writable.TextArrayWritable; import org.apache.hama.graph.AverageAggregator; import org.apache.hama.graph.Edge; import org.apache.hama.graph.GraphJob; Index: graph/src/test/java/org/apache/hama/graph/TestSubmitGraphJob.java =================================================================== --- graph/src/test/java/org/apache/hama/graph/TestSubmitGraphJob.java (revision 1535330) +++ graph/src/test/java/org/apache/hama/graph/TestSubmitGraphJob.java (working copy) @@ -36,7 +36,7 @@ import org.apache.hama.bsp.SequenceFileInputFormat; import org.apache.hama.bsp.SequenceFileOutputFormat; import org.apache.hama.bsp.TestBSPMasterGroomServer; -import org.apache.hama.bsp.TextArrayWritable; +import org.apache.hama.commons.io.writable.TextArrayWritable; import org.apache.hama.graph.example.PageRank; import org.apache.hama.graph.example.PageRank.PagerankSeqReader; import org.junit.Before; Index: graph/src/main/java/org/apache/hama/graph/VertexInputReader.java =================================================================== --- graph/src/main/java/org/apache/hama/graph/VertexInputReader.java (revision 1535330) +++ graph/src/main/java/org/apache/hama/graph/VertexInputReader.java (working copy) @@ -31,7 +31,7 @@ import org.apache.hama.bsp.BSPPeer; import org.apache.hama.bsp.Partitioner; import org.apache.hama.bsp.PartitioningRunner.RecordConverter; -import org.apache.hama.util.KeyValuePair; +import org.apache.hama.commons.util.KeyValuePair; /** * A reader to read Hama's input files and parses a vertex out of it. 
Index: graph/src/main/java/org/apache/hama/graph/GraphJobRunner.java =================================================================== --- graph/src/main/java/org/apache/hama/graph/GraphJobRunner.java (revision 1535330) +++ graph/src/main/java/org/apache/hama/graph/GraphJobRunner.java (working copy) @@ -39,8 +39,8 @@ import org.apache.hama.bsp.PartitioningRunner.DefaultRecordConverter; import org.apache.hama.bsp.PartitioningRunner.RecordConverter; import org.apache.hama.bsp.sync.SyncException; +import org.apache.hama.commons.util.KeyValuePair; import org.apache.hama.graph.IDSkippingIterator.Strategy; -import org.apache.hama.util.KeyValuePair; import org.apache.hama.util.ReflectionUtils; /** Index: graph/pom.xml =================================================================== --- graph/pom.xml (revision 1535330) +++ graph/pom.xml (working copy) @@ -33,6 +33,11 @@ org.apache.hama + hama-commons + ${project.version} + + + org.apache.hama hama-core ${project.version} Index: ml/src/test/java/org/apache/hama/ml/kmeans/TestKMeansBSP.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/kmeans/TestKMeansBSP.java (revision 1535330) +++ ml/src/test/java/org/apache/hama/ml/kmeans/TestKMeansBSP.java (working copy) @@ -28,8 +28,7 @@ import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hama.bsp.BSPJob; -import org.apache.hama.ml.kmeans.KMeansBSP; -import org.apache.hama.ml.math.DoubleVector; +import org.apache.hama.commons.math.DoubleVector; public class TestKMeansBSP extends TestCase { Index: ml/src/test/java/org/apache/hama/ml/regression/VectorDoubleFileInputFormatTest.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/regression/VectorDoubleFileInputFormatTest.java (revision 1535330) +++ ml/src/test/java/org/apache/hama/ml/regression/VectorDoubleFileInputFormatTest.java (working copy) @@ -17,18 +17,20 @@ */ 
package org.apache.hama.ml.regression; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.DoubleWritable; import org.apache.hama.bsp.BSPJob; import org.apache.hama.bsp.FileSplit; import org.apache.hama.bsp.InputSplit; import org.apache.hama.bsp.RecordReader; -import org.apache.hama.ml.math.DenseDoubleVector; -import org.apache.hama.ml.writable.VectorWritable; +import org.apache.hama.commons.io.writable.VectorWritable; +import org.apache.hama.commons.math.DenseDoubleVector; import org.junit.Test; -import static org.junit.Assert.*; - /** * Testcase for {@link VectorDoubleFileInputFormat} */ Index: ml/src/test/java/org/apache/hama/ml/regression/TestLinearRegression.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/regression/TestLinearRegression.java (revision 1535330) +++ ml/src/test/java/org/apache/hama/ml/regression/TestLinearRegression.java (working copy) @@ -25,8 +25,8 @@ import java.util.ArrayList; import java.util.List; -import org.apache.hama.ml.math.DenseDoubleVector; -import org.apache.hama.ml.math.DoubleVector; +import org.apache.hama.commons.math.DenseDoubleVector; +import org.apache.hama.commons.math.DoubleVector; import org.junit.Test; import org.mortbay.log.Log; Index: ml/src/test/java/org/apache/hama/ml/regression/LinearRegressionModelTest.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/regression/LinearRegressionModelTest.java (revision 1535330) +++ ml/src/test/java/org/apache/hama/ml/regression/LinearRegressionModelTest.java (working copy) @@ -21,8 +21,8 @@ import java.math.BigDecimal; -import org.apache.hama.ml.math.DenseDoubleVector; -import org.apache.hama.ml.math.DoubleVector; +import org.apache.hama.commons.math.DenseDoubleVector; +import 
org.apache.hama.commons.math.DoubleVector; import org.junit.Test; /** Index: ml/src/test/java/org/apache/hama/ml/regression/TestLogisticRegression.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/regression/TestLogisticRegression.java (revision 1535330) +++ ml/src/test/java/org/apache/hama/ml/regression/TestLogisticRegression.java (working copy) @@ -26,8 +26,8 @@ import java.util.Arrays; import java.util.List; -import org.apache.hama.ml.math.DenseDoubleVector; -import org.apache.hama.ml.math.DoubleVector; +import org.apache.hama.commons.math.DenseDoubleVector; +import org.apache.hama.commons.math.DoubleVector; import org.junit.Test; import org.mortbay.log.Log; Index: ml/src/test/java/org/apache/hama/ml/regression/LogisticRegressionModelTest.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/regression/LogisticRegressionModelTest.java (revision 1535330) +++ ml/src/test/java/org/apache/hama/ml/regression/LogisticRegressionModelTest.java (working copy) @@ -22,8 +22,8 @@ import java.math.BigDecimal; -import org.apache.hama.ml.math.DenseDoubleVector; -import org.apache.hama.ml.math.DoubleVector; +import org.apache.hama.commons.math.DenseDoubleVector; +import org.apache.hama.commons.math.DoubleVector; import org.junit.Test; /** Index: ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java (revision 1535330) +++ ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java (working copy) @@ -19,7 +19,6 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; import java.io.IOException; import java.net.URI; @@ -34,12 +33,12 @@ import org.apache.hadoop.io.LongWritable; import 
org.apache.hadoop.io.SequenceFile; import org.apache.hadoop.io.WritableUtils; -import org.apache.hama.ml.math.DenseDoubleMatrix; -import org.apache.hama.ml.math.DenseDoubleVector; -import org.apache.hama.ml.math.DoubleMatrix; -import org.apache.hama.ml.math.DoubleVector; -import org.apache.hama.ml.writable.MatrixWritable; -import org.apache.hama.ml.writable.VectorWritable; +import org.apache.hama.commons.io.writable.MatrixWritable; +import org.apache.hama.commons.io.writable.VectorWritable; +import org.apache.hama.commons.math.DenseDoubleMatrix; +import org.apache.hama.commons.math.DenseDoubleVector; +import org.apache.hama.commons.math.DoubleMatrix; +import org.apache.hama.commons.math.DoubleVector; import org.junit.Test; import org.mortbay.log.Log; Index: ml/src/test/java/org/apache/hama/ml/perception/TestSmallMLPMessage.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/perception/TestSmallMLPMessage.java (revision 1535330) +++ ml/src/test/java/org/apache/hama/ml/perception/TestSmallMLPMessage.java (working copy) @@ -29,7 +29,7 @@ import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; -import org.apache.hama.ml.math.DenseDoubleMatrix; +import org.apache.hama.commons.math.DenseDoubleMatrix; import org.junit.Test; /** Index: ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java (revision 1535330) +++ ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleVector.java (working copy) @@ -1,208 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.math; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; - -import org.junit.Test; - -/** - * Testcase for {@link DenseDoubleVector} - * - */ -public class TestDenseDoubleVector { - - @Test - public void testApplyDoubleFunction() { - double[] values = new double[] {1, 2, 3, 4, 5}; - double[] result = new double[] {2, 3, 4, 5, 6}; - - DoubleVector vec1 = new DenseDoubleVector(values); - - vec1.applyToElements(new DoubleFunction() { - - @Override - public double apply(double value) { - return value + 1; - } - - @Override - public double applyDerivative(double value) { - throw new UnsupportedOperationException("Not supported."); - } - - }); - - assertArrayEquals(result, vec1.toArray(), 0.0001); - } - - @Test - public void testApplyDoubleDoubleFunction() { - double[] values1 = new double[] {1, 2, 3, 4, 5, 6}; - double[] values2 = new double[] {7, 8, 9, 10, 11, 12}; - double[] result = new double[] {8, 10, 12, 14, 16, 18}; - - DoubleVector vec1 = new DenseDoubleVector(values1); - DoubleVector vec2 = new DenseDoubleVector(values2); - - vec1.applyToElements(vec2, new DoubleDoubleFunction() { - - @Override - public double apply(double x1, double x2) { - return x1 + x2; - } - - @Override - public double applyDerivative(double x1, double x2) { - throw new UnsupportedOperationException("Not supported"); - } - - }); - - 
assertArrayEquals(result, vec1.toArray(), 0.0001); - - } - - @Test - public void testAddNormal() { - double[] arr1 = new double[] {1, 2, 3}; - double[] arr2 = new double[] {4, 5, 6}; - DoubleVector vec1 = new DenseDoubleVector(arr1); - DoubleVector vec2 = new DenseDoubleVector(arr2); - double[] arrExp = new double[] {5, 7, 9}; - assertArrayEquals(arrExp, vec1.add(vec2).toArray(), 0.000001); - } - - @Test(expected = IllegalArgumentException.class) - public void testAddAbnormal() { - double[] arr1 = new double[] {1, 2, 3}; - double[] arr2 = new double[] {4, 5}; - DoubleVector vec1 = new DenseDoubleVector(arr1); - DoubleVector vec2 = new DenseDoubleVector(arr2); - vec1.add(vec2); - } - - @Test - public void testSubtractNormal() { - double[] arr1 = new double[] {1, 2, 3}; - double[] arr2 = new double[] {4, 5, 6}; - DoubleVector vec1 = new DenseDoubleVector(arr1); - DoubleVector vec2 = new DenseDoubleVector(arr2); - double[] arrExp = new double[] {-3, -3, -3}; - assertArrayEquals(arrExp, vec1.subtract(vec2).toArray(), 0.000001); - } - - @Test(expected = IllegalArgumentException.class) - public void testSubtractAbnormal() { - double[] arr1 = new double[] {1, 2, 3}; - double[] arr2 = new double[] {4, 5}; - DoubleVector vec1 = new DenseDoubleVector(arr1); - DoubleVector vec2 = new DenseDoubleVector(arr2); - vec1.subtract(vec2); - } - - @Test - public void testMultiplyNormal() { - double[] arr1 = new double[] {1, 2, 3}; - double[] arr2 = new double[] {4, 5, 6}; - DoubleVector vec1 = new DenseDoubleVector(arr1); - DoubleVector vec2 = new DenseDoubleVector(arr2); - double[] arrExp = new double[] {4, 10, 18}; - assertArrayEquals(arrExp, vec1.multiply(vec2).toArray(), 0.000001); - } - - @Test(expected = IllegalArgumentException.class) - public void testMultiplyAbnormal() { - double[] arr1 = new double[] {1, 2, 3}; - double[] arr2 = new double[] {4, 5}; - DoubleVector vec1 = new DenseDoubleVector(arr1); - DoubleVector vec2 = new DenseDoubleVector(arr2); - vec1.multiply(vec2); - 
} - - @Test - public void testDotNormal() { - double[] arr1 = new double[] {1, 2, 3}; - double[] arr2 = new double[] {4, 5, 6}; - DoubleVector vec1 = new DenseDoubleVector(arr1); - DoubleVector vec2 = new DenseDoubleVector(arr2); - assertEquals(32.0, vec1.dot(vec2), 0.000001); - } - - @Test(expected = IllegalArgumentException.class) - public void testDotAbnormal() { - double[] arr1 = new double[] {1, 2, 3}; - double[] arr2 = new double[] {4, 5}; - DoubleVector vec1 = new DenseDoubleVector(arr1); - DoubleVector vec2 = new DenseDoubleVector(arr2); - vec1.add(vec2); - } - - @Test - public void testSliceNormal() { - double[] arr1 = new double[] {2, 3, 4, 5, 6}; - double[] arr2 = new double[] {4, 5, 6}; - double[] arr3 = new double[] {2, 3, 4}; - DoubleVector vec1 = new DenseDoubleVector(arr1); - assertArrayEquals(arr2, vec1.slice(2, 4).toArray(), 0.000001); - DoubleVector vec2 = new DenseDoubleVector(arr1); - assertArrayEquals(arr3, vec2.slice(3).toArray(), 0.000001); - } - - @Test(expected = IllegalArgumentException.class) - public void testSliceAbnormal() { - double[] arr1 = new double[] {2, 3, 4, 5, 6}; - DoubleVector vec = new DenseDoubleVector(arr1); - vec.slice(2, 5); - } - - @Test(expected = IllegalArgumentException.class) - public void testSliceAbnormalEndTooLarge() { - double[] arr1 = new double[] {2, 3, 4, 5, 6}; - DoubleVector vec = new DenseDoubleVector(arr1); - vec.slice(2, 5); - } - - @Test(expected = IllegalArgumentException.class) - public void testSliceAbnormalStartLargerThanEnd() { - double[] arr1 = new double[] {2, 3, 4, 5, 6}; - DoubleVector vec = new DenseDoubleVector(arr1); - vec.slice(4, 3); - } - - @Test - public void testVectorMultiplyMatrix() { - DoubleVector vec = new DenseDoubleVector(new double[]{1, 2, 3}); - DoubleMatrix mat = new DenseDoubleMatrix(new double[][] { - {1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12} - }); - double[] expectedRes = new double[] {38, 44, 50, 56}; - - assertArrayEquals(expectedRes, vec.multiply(mat).toArray(), 
0.000001); - } - - @Test(expected = IllegalArgumentException.class) - public void testVectorMultiplyMatrixAbnormal() { - DoubleVector vec = new DenseDoubleVector(new double[]{1, 2, 3}); - DoubleMatrix mat = new DenseDoubleMatrix(new double[][] { - {1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16} - }); - vec.multiply(mat); - } -} Index: ml/src/test/java/org/apache/hama/ml/math/TestFunctionFactory.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/math/TestFunctionFactory.java (revision 1535330) +++ ml/src/test/java/org/apache/hama/ml/math/TestFunctionFactory.java (working copy) @@ -1,82 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.hama.ml.math; - -import static org.junit.Assert.assertEquals; - -import java.util.Random; - -import org.junit.Test; - -/** - * Test case for {@link FunctionFactory} - * - */ -public class TestFunctionFactory { - - @Test - public void testCreateDoubleFunction() { - double input = 0.8; - - String sigmoidName = "Sigmoid"; - DoubleFunction sigmoidFunction = FunctionFactory - .createDoubleFunction(sigmoidName); - assertEquals(sigmoidName, sigmoidFunction.getFunctionName()); - - double sigmoidExcepted = 0.68997448; - assertEquals(sigmoidExcepted, sigmoidFunction.apply(input), 0.000001); - - - String tanhName = "Tanh"; - DoubleFunction tanhFunction = FunctionFactory.createDoubleFunction(tanhName); - assertEquals(tanhName, tanhFunction.getFunctionName()); - - double tanhExpected = 0.66403677; - assertEquals(tanhExpected, tanhFunction.apply(input), 0.00001); - - - String identityFunctionName = "IdentityFunction"; - DoubleFunction identityFunction = FunctionFactory.createDoubleFunction(identityFunctionName); - - Random rnd = new Random(); - double identityExpected = rnd.nextDouble(); - assertEquals(identityExpected, identityFunction.apply(identityExpected), 0.000001); - } - - @Test - public void testCreateDoubleDoubleFunction() { - double target = 0.5; - double output = 0.8; - - String squaredErrorName = "SquaredError"; - DoubleDoubleFunction squaredErrorFunction = FunctionFactory.createDoubleDoubleFunction(squaredErrorName); - assertEquals(squaredErrorName, squaredErrorFunction.getFunctionName()); - - double squaredErrorExpected = 0.045; - - assertEquals(squaredErrorExpected, squaredErrorFunction.apply(target, output), 0.000001); - - String crossEntropyName = "CrossEntropy"; - DoubleDoubleFunction crossEntropyFunction = FunctionFactory.createDoubleDoubleFunction(crossEntropyName); - assertEquals(crossEntropyName, crossEntropyFunction.getFunctionName()); - - double crossEntropyExpected = 0.91629; - assertEquals(crossEntropyExpected, 
crossEntropyFunction.apply(target, output), 0.000001); - } - -} Index: ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java (revision 1535330) +++ ml/src/test/java/org/apache/hama/ml/math/TestDenseDoubleMatrix.java (working copy) @@ -1,239 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.hama.ml.math; - -import static org.junit.Assert.assertArrayEquals; - -import java.util.Arrays; - -import org.junit.Test; - -/** - * Test case for {@link DenseDoubleMatrix} - * - */ -public class TestDenseDoubleMatrix { - - @Test - public void testDoubleFunction() { - double[][] values = new double[][] { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } }; - - double[][] result = new double[][] { { 2, 3, 4 }, { 5, 6, 7 }, { 8, 9, 10 } }; - - DenseDoubleMatrix mat = new DenseDoubleMatrix(values); - mat.applyToElements(new DoubleFunction() { - - @Override - public double apply(double value) { - return value + 1; - } - - @Override - public double applyDerivative(double value) { - throw new UnsupportedOperationException(); - } - - }); - - double[][] actual = mat.getValues(); - for (int i = 0; i < actual.length; ++i) { - assertArrayEquals(result[i], actual[i], 0.0001); - } - } - - @Test - public void testDoubleDoubleFunction() { - double[][] values1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } }; - double[][] values2 = new double[][] { { 2, 3, 4 }, { 5, 6, 7 }, - { 8, 9, 10 } }; - double[][] result = new double[][] { { 3, 5, 7 }, { 9, 11, 13 }, - { 15, 17, 19 } }; - - DenseDoubleMatrix mat1 = new DenseDoubleMatrix(values1); - DenseDoubleMatrix mat2 = new DenseDoubleMatrix(values2); - - mat1.applyToElements(mat2, new DoubleDoubleFunction() { - - @Override - public double apply(double x1, double x2) { - return x1 + x2; - } - - @Override - public double applyDerivative(double x1, double x2) { - throw new UnsupportedOperationException(); - } - - }); - - double[][] actual = mat1.getValues(); - for (int i = 0; i < actual.length; ++i) { - assertArrayEquals(result[i], actual[i], 0.0001); - } - } - - @Test - public void testMultiplyNormal() { - double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; - double[][] mat2 = new double[][] { { 6, 5 }, { 4, 3 }, { 2, 1 } }; - double[][] expMat = new double[][] { { 20, 14 }, { 56, 41 } }; - DoubleMatrix 
matrix1 = new DenseDoubleMatrix(mat1); - DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); - DoubleMatrix actMatrix = matrix1.multiply(matrix2); - for (int r = 0; r < actMatrix.getRowCount(); ++r) { - assertArrayEquals(expMat[r], actMatrix.getRowVector(r).toArray(), - 0.000001); - } - } - - @Test(expected = IllegalArgumentException.class) - public void testMultiplyAbnormal() { - double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; - double[][] mat2 = new double[][] { { 6, 5 }, { 4, 3 } }; - DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); - DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); - matrix1.multiply(matrix2); - } - - @Test - public void testMultiplyElementWiseNormal() { - double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; - double[][] mat2 = new double[][] { { 6, 5, 4 }, { 3, 2, 1 } }; - double[][] expMat = new double[][] { { 6, 10, 12 }, { 12, 10, 6 } }; - DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); - DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); - DoubleMatrix actMatrix = matrix1.multiplyElementWise(matrix2); - for (int r = 0; r < actMatrix.getRowCount(); ++r) { - assertArrayEquals(expMat[r], actMatrix.getRowVector(r).toArray(), - 0.000001); - } - } - - @Test(expected = IllegalArgumentException.class) - public void testMultiplyElementWiseAbnormal() { - double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; - double[][] mat2 = new double[][] { { 6, 5 }, { 4, 3 } }; - DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); - DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); - matrix1.multiplyElementWise(matrix2); - } - - @Test - public void testMultiplyVectorNormal() { - double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; - double[] mat2 = new double[] { 6, 5, 4 }; - double[] expVec = new double[] { 28, 73 }; - DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); - DoubleVector vector2 = new DenseDoubleVector(mat2); - DoubleVector actVec = matrix1.multiplyVector(vector2); - 
assertArrayEquals(expVec, actVec.toArray(), 0.000001); - } - - @Test(expected = IllegalArgumentException.class) - public void testMultiplyVectorAbnormal() { - double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; - double[] vec2 = new double[] { 6, 5 }; - DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); - DoubleVector vector2 = new DenseDoubleVector(vec2); - matrix1.multiplyVector(vector2); - } - - @Test - public void testSubtractNormal() { - double[][] mat1 = new double[][] { - {1, 2, 3}, - {4, 5, 6} - }; - double[][] mat2 = new double[][] { - {6, 5, 4}, - {3, 2, 1} - }; - double[][] expMat = new double[][] { - {-5, -3, -1}, - {1, 3, 5} - }; - DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); - DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); - DoubleMatrix actMatrix = matrix1.subtract(matrix2); - for (int r = 0; r < actMatrix.getRowCount(); ++r) { - assertArrayEquals(expMat[r], actMatrix.getRowVector(r).toArray(), 0.000001); - } - } - - @Test(expected = IllegalArgumentException.class) - public void testSubtractAbnormal() { - double[][] mat1 = new double[][] { - {1, 2, 3}, - {4, 5, 6} - }; - double[][] mat2 = new double[][] { - {6, 5}, - {4, 3} - }; - DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); - DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); - matrix1.subtract(matrix2); - } - - @Test - public void testDivideVectorNormal() { - double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; - double[] mat2 = new double[] { 6, 5, 4 }; - double[][] expVec = new double[][] { {1.0 / 6, 2.0 / 5, 3.0 / 4}, {4.0 / 6, 5.0 / 5, 6.0 / 4} }; - DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); - DoubleVector vector2 = new DenseDoubleVector(mat2); - DoubleMatrix expMat = new DenseDoubleMatrix(expVec); - DoubleMatrix actMat = matrix1.divide(vector2); - for (int r = 0; r < actMat.getRowCount(); ++r) { - assertArrayEquals(expMat.getRowVector(r).toArray(), actMat.getRowVector(r).toArray(), 0.000001); - } - } - - @Test(expected = 
IllegalArgumentException.class) - public void testDivideVectorAbnormal() { - double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; - double[] vec2 = new double[] { 6, 5 }; - DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); - DoubleVector vector2 = new DenseDoubleVector(vec2); - matrix1.divide(vector2); - } - - @Test - public void testDivideNormal() { - double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; - double[][] mat2 = new double[][] { { 6, 5, 4 }, { 3, 2, 1 } }; - double[][] expMat = new double[][] { { 1.0 / 6, 2.0 / 5, 3.0 / 4 }, { 4.0 / 3, 5.0 / 2, 6.0 / 1 } }; - DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); - DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); - DoubleMatrix actMatrix = matrix1.divide(matrix2); - for (int r = 0; r < actMatrix.getRowCount(); ++r) { - assertArrayEquals(expMat[r], actMatrix.getRowVector(r).toArray(), - 0.000001); - } - } - - @Test(expected = IllegalArgumentException.class) - public void testDivideAbnormal() { - double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; - double[][] mat2 = new double[][] { { 6, 5 }, { 4, 3 } }; - DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); - DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); - matrix1.divide(matrix2); - } - -} Index: ml/src/test/java/org/apache/hama/ml/ann/TestAutoEncoder.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/ann/TestAutoEncoder.java (revision 1535330) +++ ml/src/test/java/org/apache/hama/ml/ann/TestAutoEncoder.java (working copy) @@ -36,10 +36,10 @@ import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.SequenceFile; +import org.apache.hama.commons.io.writable.VectorWritable; +import org.apache.hama.commons.math.DenseDoubleVector; +import org.apache.hama.commons.math.DoubleVector; import org.apache.hama.ml.MLTestBase; -import org.apache.hama.ml.math.DenseDoubleVector; -import org.apache.hama.ml.math.DoubleVector; 
-import org.apache.hama.ml.writable.VectorWritable; import org.junit.Ignore; import org.junit.Test; import org.mortbay.log.Log; Index: ml/src/test/java/org/apache/hama/ml/ann/TestSmallLayeredNeuralNetwork.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/ann/TestSmallLayeredNeuralNetwork.java (revision 1535330) +++ ml/src/test/java/org/apache/hama/ml/ann/TestSmallLayeredNeuralNetwork.java (working copy) @@ -38,15 +38,15 @@ import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.SequenceFile; +import org.apache.hama.commons.io.writable.VectorWritable; +import org.apache.hama.commons.math.DenseDoubleMatrix; +import org.apache.hama.commons.math.DenseDoubleVector; +import org.apache.hama.commons.math.DoubleMatrix; +import org.apache.hama.commons.math.DoubleVector; +import org.apache.hama.commons.math.FunctionFactory; import org.apache.hama.ml.MLTestBase; import org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork.LearningStyle; import org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork.TrainingMethod; -import org.apache.hama.ml.math.DenseDoubleMatrix; -import org.apache.hama.ml.math.DenseDoubleVector; -import org.apache.hama.ml.math.DoubleMatrix; -import org.apache.hama.ml.math.DoubleVector; -import org.apache.hama.ml.math.FunctionFactory; -import org.apache.hama.ml.writable.VectorWritable; import org.junit.Test; import org.mortbay.log.Log; Index: ml/src/test/java/org/apache/hama/ml/ann/TestSmallLayeredNeuralNetworkMessage.java =================================================================== --- ml/src/test/java/org/apache/hama/ml/ann/TestSmallLayeredNeuralNetworkMessage.java (revision 1535330) +++ ml/src/test/java/org/apache/hama/ml/ann/TestSmallLayeredNeuralNetworkMessage.java (working copy) @@ -32,8 +32,8 @@ import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; -import 
org.apache.hama.ml.math.DenseDoubleMatrix; -import org.apache.hama.ml.math.DoubleMatrix; +import org.apache.hama.commons.math.DenseDoubleMatrix; +import org.apache.hama.commons.math.DoubleMatrix; import org.junit.Test; /** @@ -86,7 +86,7 @@ DoubleMatrix[] readPrevMatrices = readMessage.getPrevMatrices(); assertNull(readPrevMatrices); - + // delete fs.delete(path, true); } catch (IOException e) { Index: ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java (working copy) @@ -39,13 +39,13 @@ import org.apache.hadoop.io.WritableUtils; import org.apache.hama.HamaConfiguration; import org.apache.hama.bsp.BSPJob; -import org.apache.hama.ml.math.DenseDoubleMatrix; -import org.apache.hama.ml.math.DenseDoubleVector; -import org.apache.hama.ml.math.DoubleFunction; -import org.apache.hama.ml.math.DoubleVector; -import org.apache.hama.ml.math.FunctionFactory; -import org.apache.hama.ml.writable.MatrixWritable; -import org.apache.hama.ml.writable.VectorWritable; +import org.apache.hama.commons.io.writable.MatrixWritable; +import org.apache.hama.commons.io.writable.VectorWritable; +import org.apache.hama.commons.math.DenseDoubleMatrix; +import org.apache.hama.commons.math.DenseDoubleVector; +import org.apache.hama.commons.math.DoubleFunction; +import org.apache.hama.commons.math.DoubleVector; +import org.apache.hama.commons.math.FunctionFactory; import org.mortbay.log.Log; /** Index: ml/src/main/java/org/apache/hama/ml/perception/PerceptronTrainer.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/PerceptronTrainer.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/perception/PerceptronTrainer.java (working copy) @@ -25,7 
+25,7 @@ import org.apache.hama.bsp.BSP; import org.apache.hama.bsp.BSPPeer; import org.apache.hama.bsp.sync.SyncException; -import org.apache.hama.ml.writable.VectorWritable; +import org.apache.hama.commons.io.writable.VectorWritable; /** * The trainer that is used to train the perceptron with BSP. The trainer would Index: ml/src/main/java/org/apache/hama/ml/perception/SmallMLPMessage.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/SmallMLPMessage.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/perception/SmallMLPMessage.java (working copy) @@ -21,8 +21,8 @@ import java.io.DataOutput; import java.io.IOException; -import org.apache.hama.ml.math.DenseDoubleMatrix; -import org.apache.hama.ml.writable.MatrixWritable; +import org.apache.hama.commons.io.writable.MatrixWritable; +import org.apache.hama.commons.math.DenseDoubleMatrix; /** * SmallMLPMessage is used to exchange information for the Index: ml/src/main/java/org/apache/hama/ml/perception/SmallMLPTrainer.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/SmallMLPTrainer.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/perception/SmallMLPTrainer.java (working copy) @@ -25,9 +25,9 @@ import org.apache.hadoop.io.NullWritable; import org.apache.hama.bsp.BSPPeer; import org.apache.hama.bsp.sync.SyncException; +import org.apache.hama.commons.io.writable.VectorWritable; +import org.apache.hama.commons.math.DenseDoubleMatrix; import org.apache.hama.ml.ann.NeuralNetworkTrainer; -import org.apache.hama.ml.math.DenseDoubleMatrix; -import org.apache.hama.ml.writable.VectorWritable; /** * The perceptron trainer for small scale MLP. 
Index: ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java (working copy) @@ -21,11 +21,11 @@ import java.util.Map; import org.apache.hadoop.fs.Path; +import org.apache.hama.commons.math.DoubleDoubleFunction; +import org.apache.hama.commons.math.DoubleFunction; +import org.apache.hama.commons.math.DoubleVector; +import org.apache.hama.commons.math.FunctionFactory; import org.apache.hama.ml.ann.NeuralNetworkTrainer; -import org.apache.hama.ml.math.DoubleDoubleFunction; -import org.apache.hama.ml.math.DoubleFunction; -import org.apache.hama.ml.math.DoubleVector; -import org.apache.hama.ml.math.FunctionFactory; /** * PerceptronBase defines the common behavior of all the concrete perceptrons. Index: ml/src/main/java/org/apache/hama/ml/math/IdentityFunction.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/IdentityFunction.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/math/IdentityFunction.java (working copy) @@ -1,36 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.math; - -/** - * The identity function f(x) = x. - * - */ -public class IdentityFunction extends DoubleFunction { - - @Override - public double apply(double value) { - return value; - } - - @Override - public double applyDerivative(double value) { - return 1; - } - -} Index: ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleMatrix.java (working copy) @@ -1,273 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.math; - -/** - * Standard matrix interface for double elements. Every implementation should - * return a fresh new Matrix when operating with other elements. - */ -public interface DoubleMatrix { - - /** - * Not flagged value for sparse matrices, it is default to 0.0d. - */ - public static final double NOT_FLAGGED = 0.0d; - - /** - * Get a specific value of the matrix. 
- * - * @return Returns the integer value at in the column at the row. - */ - public double get(int row, int col); - - /** - * Returns the number of columns in the matrix. Always a constant time - * operation. - */ - public int getColumnCount(); - - /** - * Get a whole column of the matrix as vector. - */ - public DoubleVector getColumnVector(int col); - - /** - * Returns the number of rows in this matrix. Always a constant time - * operation. - */ - public int getRowCount(); - - /** - * Get a single row of the matrix as a vector. - */ - public DoubleVector getRowVector(int row); - - /** - * Sets the value at the given row and column index. - */ - public void set(int row, int col, double value); - - /** - * Sets a whole column at index col with the given vector. - */ - public void setColumnVector(int col, DoubleVector column); - - /** - * Sets the whole row at index rowIndex with the given vector. - */ - public void setRowVector(int rowIndex, DoubleVector row); - - /** - * Multiplies this matrix (each element) with the given scalar and returns a - * new matrix. - */ - public DoubleMatrix multiply(double scalar); - - /** - * Multiplies this matrix with the given other matrix. - * - * @param other the other matrix. - * @return - */ - public DoubleMatrix multiplyUnsafe(DoubleMatrix other); - - /** - * Validates the input and multiplies this matrix with the given other matrix. - * - * @param other the other matrix. - * @return - */ - public DoubleMatrix multiply(DoubleMatrix other); - - /** - * Multiplies this matrix per element with a given matrix. - */ - public DoubleMatrix multiplyElementWiseUnsafe(DoubleMatrix other); - - /** - * Validates the input and multiplies this matrix per element with a given - * matrix. - * - * @param other the other matrix - * @return - */ - public DoubleMatrix multiplyElementWise(DoubleMatrix other); - - /** - * Multiplies this matrix with a given vector v. The returning vector contains - * the sum of the rows. 
- */ - public DoubleVector multiplyVectorUnsafe(DoubleVector v); - - /** - * Multiplies this matrix with a given vector v. The returning vector contains - * the sum of the rows. - * - * @param v the vector - * @return - */ - public DoubleVector multiplyVector(DoubleVector v); - - /** - * Transposes this matrix. - */ - public DoubleMatrix transpose(); - - /** - * Substracts the given amount by each element in this matrix.
- * = (amount - matrix value) - */ - public DoubleMatrix subtractBy(double amount); - - /** - * Subtracts each element in this matrix by the given amount.
- * = (matrix value - amount) - */ - public DoubleMatrix subtract(double amount); - - /** - * Subtracts this matrix by the given other matrix. - */ - public DoubleMatrix subtractUnsafe(DoubleMatrix other); - - /** - * Validates the input and subtracts this matrix by the given other matrix. - * - * @param other - * @return - */ - public DoubleMatrix subtract(DoubleMatrix other); - - /** - * Subtracts each element in a column by the related element in the given - * vector. - */ - public DoubleMatrix subtractUnsafe(DoubleVector vec); - - /** - * Validates and subtracts each element in a column by the related element in - * the given vector. - * - * @param vec - * @return - */ - public DoubleMatrix subtract(DoubleVector vec); - - /** - * Divides each element in a column by the related element in the given - * vector. - */ - public DoubleMatrix divideUnsafe(DoubleVector vec); - - /** - * Validates and divides each element in a column by the related element in - * the given vector. - * - * @param vec - * @return - */ - public DoubleMatrix divide(DoubleVector vec); - - /** - * Divides this matrix by the given other matrix. (Per element division). - */ - public DoubleMatrix divideUnsafe(DoubleMatrix other); - - /** - * Validates and divides this matrix by the given other matrix. (Per element - * division). - * - * @param other - * @return - */ - public DoubleMatrix divide(DoubleMatrix other); - - /** - * Divides each element in this matrix by the given scalar. - */ - public DoubleMatrix divide(double scalar); - - /** - * Adds the elements in the given matrix to the elements in this matrix. - */ - public DoubleMatrix add(DoubleMatrix other); - - /** - * Pows each element by the given argument.
- * = (matrix element^x) - */ - public DoubleMatrix pow(int x); - - /** - * Returns the maximum value of the given column. - */ - public double max(int column); - - /** - * Returns the minimum value of the given column. - */ - public double min(int column); - - /** - * Sums all elements. - */ - public double sum(); - - /** - * Returns an array of column indices existing in this matrix. - */ - public int[] columnIndices(); - - /** - * Returns true if the underlying implementation is sparse. - */ - public boolean isSparse(); - - /** - * Slices the given matrix from 0-rows and from 0-columns. - */ - public DoubleMatrix slice(int rows, int cols); - - /** - * Slices the given matrix from rowOffset-rowMax and from colOffset-colMax. - */ - public DoubleMatrix slice(int rowOffset, int rowMax, int colOffset, int colMax); - - /** - * Apply a double function f(x) onto each element of the matrix. After - * applying, each element of the current matrix will be changed from x to - * f(x). - * - * @param fun The function. - * @return The matrix itself, supply for chain operation. - */ - public DoubleMatrix applyToElements(DoubleFunction fun); - - /** - * Apply a double double function f(x, y) onto each pair of the current matrix - * elements and given matrix. After applying, each element of the current - * matrix will be changed from x to f(x, y). - * - * @param other The matrix contributing the second argument of the function. - * @param fun The function that takes two arguments. - * @return The matrix itself, supply for chain operation. 
- */ - public DoubleMatrix applyToElements(DoubleMatrix other, - DoubleDoubleFunction fun); - -} Index: ml/src/main/java/org/apache/hama/ml/math/DoubleVectorFunction.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleVectorFunction.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleVectorFunction.java (working copy) @@ -1,34 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.math; - -/** - * A function that can be applied to a double vector via {@link DoubleVector} - * #apply({@link DoubleVectorFunction} f); - * - * This class will be replaced by {@link DoubleFunction} - */ -@Deprecated -public interface DoubleVectorFunction { - - /** - * Calculates the result with a given index and value of a vector. 
- */ - public double calculate(int index, double value); - -} Index: ml/src/main/java/org/apache/hama/ml/math/CrossEntropy.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/CrossEntropy.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/math/CrossEntropy.java (working copy) @@ -1,58 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.math; - -/** - * The cross entropy cost function. - * - *
- * cost(t, y) = - t * log(y) - (1 - t) * log(1 - y),
- * where t denotes the target value, y denotes the estimated value.
- * 
- */ -public class CrossEntropy extends DoubleDoubleFunction { - - @Override - public double apply(double target, double actual) { - double adjustedTarget = (target == 0 ? 0.000001 : target); - adjustedTarget = (target == 1.0 ? 0.999999 : target); - double adjustedActual = (actual == 0 ? 0.000001 : actual); - adjustedActual = (actual == 1 ? 0.999999 : actual); - return -adjustedTarget * Math.log(adjustedActual) - (1 - adjustedTarget) - * Math.log(1 - adjustedActual); - } - - @Override - public double applyDerivative(double target, double actual) { - double adjustedTarget = target; - double adjustedActual = actual; - if (adjustedActual == 1) { - adjustedActual = 0.999; - } else if (actual == 0) { - adjustedActual = 0.001; - } - if (adjustedTarget == 1) { - adjustedTarget = 0.999; - } else if (adjustedTarget == 0) { - adjustedTarget = 0.001; - } - return -adjustedTarget / adjustedActual + (1 - adjustedTarget) - / (1 - adjustedActual); - } - -} Index: ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleFunction.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleFunction.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleFunction.java (working copy) @@ -1,45 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.math; - -/** - * A double double function takes two arguments. A vector or matrix can apply - * the double function to each element. - * - */ -public abstract class DoubleDoubleFunction extends Function { - - /** - * Apply the function to elements to two given arguments. - * - * @param x1 - * @param x2 - * @return The result based on the calculation on two arguments. - */ - public abstract double apply(double x1, double x2); - - /** - * Apply the derivative of this function to two given arguments. - * - * @param x1 - * @param x2 - * @return The result based on the calculation on two arguments. - */ - public abstract double applyDerivative(double x1, double x2); - -} Index: ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/math/DenseDoubleVector.java (working copy) @@ -1,739 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.hama.ml.math; - -import java.math.BigDecimal; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.Comparator; -import java.util.Iterator; -import java.util.List; - -import com.google.common.base.Preconditions; -import com.google.common.collect.AbstractIterator; - -/** - * Dense double vector implementation. - */ -public final class DenseDoubleVector implements DoubleVector { - - private final double[] vector; - - /** - * Creates a new vector with the given length. - */ - public DenseDoubleVector(int length) { - this.vector = new double[length]; - } - - /** - * Creates a new vector with the given length and default value. - */ - public DenseDoubleVector(int length, double val) { - this(length); - Arrays.fill(vector, val); - } - - /** - * Creates a new vector with the given array. - */ - public DenseDoubleVector(double[] arr) { - this.vector = arr; - } - - /** - * Creates a new vector with the given array and the last value f1. 
- */ - public DenseDoubleVector(double[] array, double f1) { - this.vector = new double[array.length + 1]; - System.arraycopy(array, 0, this.vector, 0, array.length); - this.vector[array.length] = f1; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#get(int) - */ - @Override - public final double get(int index) { - return vector[index]; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#getLength() - */ - @Override - public final int getLength() { - return vector.length; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#getDimension() - */ - @Override - public int getDimension() { - return getLength(); - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#set(int, double) - */ - @Override - public final void set(int index, double value) { - vector[index] = value; - } - - /** - * {@inheritDoc} - */ - @Override - public DoubleVector applyToElements(DoubleFunction func) { - for (int i = 0; i < vector.length; i++) { - this.vector[i] = func.apply(vector[i]); - } - return this; - } - - /** - * {@inheritDoc} - */ - @Override - public DoubleVector applyToElements(DoubleVector other, - DoubleDoubleFunction func) { - for (int i = 0; i < vector.length; i++) { - this.vector[i] = func.apply(vector[i], other.get(i)); - } - return this; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#apply(de.jungblut.math.function. 
- * DoubleVectorFunction) - */ - @Deprecated - @Override - public DoubleVector apply(DoubleVectorFunction func) { - DenseDoubleVector newV = new DenseDoubleVector(this.vector); - for (int i = 0; i < vector.length; i++) { - newV.vector[i] = func.calculate(i, vector[i]); - } - return newV; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#apply(de.jungblut.math.DoubleVector, - * de.jungblut.math.function.DoubleDoubleVectorFunction) - */ - @Deprecated - @Override - public DoubleVector apply(DoubleVector other, DoubleDoubleVectorFunction func) { - DenseDoubleVector newV = (DenseDoubleVector) deepCopy(); - for (int i = 0; i < vector.length; i++) { - newV.vector[i] = func.calculate(i, vector[i], other.get(i)); - } - return newV; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#add(de.jungblut.math.DoubleVector) - */ - @Override - public final DoubleVector addUnsafe(DoubleVector v) { - DenseDoubleVector newv = new DenseDoubleVector(v.getLength()); - for (int i = 0; i < v.getLength(); i++) { - newv.set(i, this.get(i) + v.get(i)); - } - return newv; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#add(double) - */ - @Override - public final DoubleVector add(double scalar) { - DoubleVector newv = new DenseDoubleVector(this.getLength()); - for (int i = 0; i < this.getLength(); i++) { - newv.set(i, this.get(i) + scalar); - } - return newv; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#subtract(de.jungblut.math.DoubleVector) - */ - @Override - public final DoubleVector subtractUnsafe(DoubleVector v) { - DoubleVector newv = new DenseDoubleVector(v.getLength()); - for (int i = 0; i < v.getLength(); i++) { - newv.set(i, this.get(i) - v.get(i)); - } - return newv; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#subtract(double) - */ - @Override - public final DoubleVector subtract(double v) { - DenseDoubleVector newv = new DenseDoubleVector(vector.length); - for (int i = 0; i < 
vector.length; i++) { - newv.set(i, vector[i] - v); - } - return newv; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#subtractFrom(double) - */ - @Override - public final DoubleVector subtractFrom(double v) { - DenseDoubleVector newv = new DenseDoubleVector(vector.length); - for (int i = 0; i < vector.length; i++) { - newv.set(i, v - vector[i]); - } - return newv; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#multiply(double) - */ - @Override - public DoubleVector multiply(double scalar) { - DoubleVector v = new DenseDoubleVector(this.getLength()); - for (int i = 0; i < v.getLength(); i++) { - v.set(i, this.get(i) * scalar); - } - return v; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#multiply(de.jungblut.math.DoubleVector) - */ - @Override - public DoubleVector multiplyUnsafe(DoubleVector vector) { - DoubleVector v = new DenseDoubleVector(this.getLength()); - for (int i = 0; i < v.getLength(); i++) { - v.set(i, this.get(i) * vector.get(i)); - } - return v; - } - - @Override - public DoubleVector multiply(DoubleMatrix matrix) { - Preconditions.checkArgument(this.vector.length == matrix.getRowCount(), - "Dimension mismatch when multiply a vector to a matrix."); - return this.multiplyUnsafe(matrix); - } - - @Override - public DoubleVector multiplyUnsafe(DoubleMatrix matrix) { - DoubleVector vec = new DenseDoubleVector(matrix.getColumnCount()); - for (int i = 0; i < vec.getDimension(); ++i) { - vec.set(i, this.multiplyUnsafe(matrix.getColumnVector(i)).sum()); - } - return vec; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#divide(double) - */ - @Override - public DoubleVector divide(double scalar) { - DenseDoubleVector v = new DenseDoubleVector(this.getLength()); - for (int i = 0; i < v.getLength(); i++) { - v.set(i, this.get(i) / scalar); - } - return v; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#pow(int) - */ - @Override - public DoubleVector pow(int x) { 
- DenseDoubleVector v = new DenseDoubleVector(getLength()); - for (int i = 0; i < v.getLength(); i++) { - double value = 0.0d; - // it is faster to multiply when we having ^2 - if (x == 2) { - value = vector[i] * vector[i]; - } else { - value = Math.pow(vector[i], x); - } - v.set(i, value); - } - return v; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#sqrt() - */ - @Override - public DoubleVector sqrt() { - DoubleVector v = new DenseDoubleVector(getLength()); - for (int i = 0; i < v.getLength(); i++) { - v.set(i, Math.sqrt(vector[i])); - } - return v; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#sum() - */ - @Override - public double sum() { - double sum = 0.0d; - for (double aVector : vector) { - sum += aVector; - } - return sum; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#abs() - */ - @Override - public DoubleVector abs() { - DoubleVector v = new DenseDoubleVector(getLength()); - for (int i = 0; i < v.getLength(); i++) { - v.set(i, Math.abs(vector[i])); - } - return v; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#divideFrom(double) - */ - @Override - public DoubleVector divideFrom(double scalar) { - DoubleVector v = new DenseDoubleVector(this.getLength()); - for (int i = 0; i < v.getLength(); i++) { - if (this.get(i) != 0.0d) { - double result = scalar / this.get(i); - v.set(i, result); - } else { - v.set(i, 0.0d); - } - } - return v; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#dot(de.jungblut.math.DoubleVector) - */ - @Override - public double dotUnsafe(DoubleVector vector) { - BigDecimal dotProduct = BigDecimal.valueOf(0.0d); - for (int i = 0; i < getLength(); i++) { - dotProduct = dotProduct.add(BigDecimal.valueOf(this.get(i)).multiply(BigDecimal.valueOf(vector.get(i)))); - } - return dotProduct.doubleValue(); - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#slice(int) - */ - @Override - public DoubleVector slice(int length) { 
- return slice(0, length - 1); - } - - @Override - public DoubleVector sliceUnsafe(int length) { - return sliceUnsafe(0, length - 1); - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#slice(int, int) - */ - @Override - public DoubleVector slice(int start, int end) { - Preconditions.checkArgument(start >= 0 && start <= end - && end < vector.length, "The given from and to is invalid"); - - return sliceUnsafe(start, end); - } - - /** - * {@inheritDoc} - */ - @Override - public DoubleVector sliceUnsafe(int start, int end) { - DoubleVector newVec = new DenseDoubleVector(end - start + 1); - for (int i = start, j = 0; i <= end; ++i, ++j) { - newVec.set(j, vector[i]); - } - - return newVec; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#max() - */ - @Override - public double max() { - double max = -Double.MAX_VALUE; - for (int i = 0; i < getLength(); i++) { - double d = vector[i]; - if (d > max) { - max = d; - } - } - return max; - } - - /** - * @return the index where the maximum resides. - */ - public int maxIndex() { - double max = -Double.MAX_VALUE; - int maxIndex = 0; - for (int i = 0; i < getLength(); i++) { - double d = vector[i]; - if (d > max) { - max = d; - maxIndex = i; - } - } - return maxIndex; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#min() - */ - @Override - public double min() { - double min = Double.MAX_VALUE; - for (int i = 0; i < getLength(); i++) { - double d = vector[i]; - if (d < min) { - min = d; - } - } - return min; - } - - /** - * @return the index where the minimum resides. - */ - public int minIndex() { - double min = Double.MAX_VALUE; - int minIndex = 0; - for (int i = 0; i < getLength(); i++) { - double d = vector[i]; - if (d < min) { - min = d; - minIndex = i; - } - } - return minIndex; - } - - /** - * @return a new vector which has rinted each element. 
- */ - public DenseDoubleVector rint() { - DenseDoubleVector v = new DenseDoubleVector(getLength()); - for (int i = 0; i < getLength(); i++) { - double d = vector[i]; - v.set(i, Math.rint(d)); - } - return v; - } - - /** - * @return a new vector which has rounded each element. - */ - public DenseDoubleVector round() { - DenseDoubleVector v = new DenseDoubleVector(getLength()); - for (int i = 0; i < getLength(); i++) { - double d = vector[i]; - v.set(i, Math.round(d)); - } - return v; - } - - /** - * @return a new vector which has ceiled each element. - */ - public DenseDoubleVector ceil() { - DenseDoubleVector v = new DenseDoubleVector(getLength()); - for (int i = 0; i < getLength(); i++) { - double d = vector[i]; - v.set(i, Math.ceil(d)); - } - return v; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#toArray() - */ - @Override - public final double[] toArray() { - return vector; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#isSparse() - */ - @Override - public boolean isSparse() { - return false; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#deepCopy() - */ - @Override - public DoubleVector deepCopy() { - final double[] src = vector; - final double[] dest = new double[vector.length]; - System.arraycopy(src, 0, dest, 0, vector.length); - return new DenseDoubleVector(dest); - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#iterateNonZero() - */ - @Override - public Iterator iterateNonZero() { - return new NonZeroIterator(); - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleVector#iterate() - */ - @Override - public Iterator iterate() { - return new DefaultIterator(); - } - - @Override - public final String toString() { - if (getLength() < 20) { - return Arrays.toString(vector); - } else { - return getLength() + "x1"; - } - } - - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + Arrays.hashCode(vector); - return result; 
- } - - @Override - public boolean equals(Object obj) { - if (this == obj) - return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) - return false; - DenseDoubleVector other = (DenseDoubleVector) obj; - return Arrays.equals(vector, other.vector); - } - - /** - * Non-zero iterator for vector elements. - */ - private final class NonZeroIterator extends - AbstractIterator { - - private final DoubleVectorElement element = new DoubleVectorElement(); - private final double[] array; - private int currentIndex = 0; - - private NonZeroIterator() { - this.array = vector; - } - - @Override - protected final DoubleVectorElement computeNext() { - while (array[currentIndex] == 0.0d) { - currentIndex++; - if (currentIndex >= array.length) - return endOfData(); - } - element.setIndex(currentIndex); - element.setValue(array[currentIndex]); - return element; - } - } - - /** - * Iterator for all elements. - */ - private final class DefaultIterator extends - AbstractIterator { - - private final DoubleVectorElement element = new DoubleVectorElement(); - private final double[] array; - private int currentIndex = 0; - - private DefaultIterator() { - this.array = vector; - } - - @Override - protected final DoubleVectorElement computeNext() { - if (currentIndex < array.length) { - element.setIndex(currentIndex); - element.setValue(array[currentIndex]); - currentIndex++; - return element; - } else { - return endOfData(); - } - } - - } - - /** - * @return a new vector with dimension num and a default value of 1. - */ - public static DenseDoubleVector ones(int num) { - return new DenseDoubleVector(num, 1.0d); - } - - /** - * @return a new vector filled from index, to index, with a given stepsize. 
- */ - public static DenseDoubleVector fromUpTo(double from, double to, - double stepsize) { - DenseDoubleVector v = new DenseDoubleVector( - (int) (Math.round(((to - from) / stepsize) + 0.5))); - - for (int i = 0; i < v.getLength(); i++) { - v.set(i, from + i * stepsize); - } - return v; - } - - /** - * Some crazy sort function. - */ - public static List> sort(DoubleVector vector, - final Comparator scoreComparator) { - List> list = new ArrayList>( - vector.getLength()); - for (int i = 0; i < vector.getLength(); i++) { - list.add(new Tuple(vector.get(i), i)); - } - Collections.sort(list, new Comparator>() { - @Override - public int compare(Tuple o1, Tuple o2) { - return scoreComparator.compare(o1.getFirst(), o2.getFirst()); - } - }); - return list; - } - - @Override - public boolean isNamed() { - return false; - } - - @Override - public String getName() { - return null; - } - - /* - * (non-Javadoc) - * @see org.apache.hama.ml.math.DoubleVector#safeAdd(org.apache.hama.ml.math. - * DoubleVector) - */ - @Override - public DoubleVector add(DoubleVector vector) { - Preconditions.checkArgument(this.vector.length == vector.getDimension(), - "Dimensions of two vectors do not equal."); - return this.addUnsafe(vector); - } - - /* - * (non-Javadoc) - * @see - * org.apache.hama.ml.math.DoubleVector#safeSubtract(org.apache.hama.ml.math - * .DoubleVector) - */ - @Override - public DoubleVector subtract(DoubleVector vector) { - Preconditions.checkArgument(this.vector.length == vector.getDimension(), - "Dimensions of two vectors do not equal."); - return this.subtractUnsafe(vector); - } - - /* - * (non-Javadoc) - * @see - * org.apache.hama.ml.math.DoubleVector#safeMultiplay(org.apache.hama.ml.math - * .DoubleVector) - */ - @Override - public DoubleVector multiply(DoubleVector vector) { - Preconditions.checkArgument(this.vector.length == vector.getDimension(), - "Dimensions of two vectors do not equal."); - return this.multiplyUnsafe(vector); - } - - /* - * (non-Javadoc) - * @see 
org.apache.hama.ml.math.DoubleVector#safeDot(org.apache.hama.ml.math. - * DoubleVector) - */ - @Override - public double dot(DoubleVector vector) { - Preconditions.checkArgument(this.vector.length == vector.getDimension(), - "Dimensions of two vectors do not equal."); - return this.dotUnsafe(vector); - } - -} Index: ml/src/main/java/org/apache/hama/ml/math/FunctionFactory.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/FunctionFactory.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/math/FunctionFactory.java (working copy) @@ -1,65 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.math; - -/** - * Factory to create the functions. - * - */ -public class FunctionFactory { - - /** - * Create a double function with specified name. 
- * - * @param functionName - * @return - */ - public static DoubleFunction createDoubleFunction(String functionName) { - if (functionName.equalsIgnoreCase(Sigmoid.class.getSimpleName())) { - return new Sigmoid(); - } else if (functionName.equalsIgnoreCase(Tanh.class.getSimpleName())) { - return new Tanh(); - } else if (functionName.equalsIgnoreCase(IdentityFunction.class - .getSimpleName())) { - return new IdentityFunction(); - } - - throw new IllegalArgumentException(String.format( - "No double function with name '%s' exists.", functionName)); - } - - /** - * Create a double double function with specified name. - * - * @param functionName - * @return - */ - public static DoubleDoubleFunction createDoubleDoubleFunction( - String functionName) { - if (functionName.equalsIgnoreCase(SquaredError.class.getSimpleName())) { - return new SquaredError(); - } else if (functionName - .equalsIgnoreCase(CrossEntropy.class.getSimpleName())) { - return new CrossEntropy(); - } - - throw new IllegalArgumentException(String.format( - "No double double function with name '%s' exists.", functionName)); - } - -} Index: ml/src/main/java/org/apache/hama/ml/math/SquaredError.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/SquaredError.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/math/SquaredError.java (working copy) @@ -1,46 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.math; - -/** - * Square error cost function. - * - *
- * cost(t, y) = 0.5 * (t - y) ˆ 2
- * 
- */ -public class SquaredError extends DoubleDoubleFunction { - - @Override - /** - * {@inheritDoc} - */ - public double apply(double target, double actual) { - double diff = target - actual; - return 0.5 * diff * diff; - } - - @Override - /** - * {@inheritDoc} - */ - public double applyDerivative(double target, double actual) { - return actual - target; - } - -} Index: ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/math/DenseDoubleMatrix.java (working copy) @@ -1,904 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.math; - -import java.util.Arrays; -import java.util.HashSet; -import java.util.Random; - -import com.google.common.base.Preconditions; - -/** - * Dense double matrix implementation, internally uses two dimensional double - * arrays. 
- */ -public final class DenseDoubleMatrix implements DoubleMatrix { - - protected final double[][] matrix; - protected final int numRows; - protected final int numColumns; - - /** - * Creates a new empty matrix from the rows and columns. - * - * @param rows the num of rows. - * @param columns the num of columns. - */ - public DenseDoubleMatrix(int rows, int columns) { - this.numRows = rows; - this.numColumns = columns; - this.matrix = new double[rows][columns]; - } - - /** - * Creates a new empty matrix from the rows and columns filled with the given - * default value. - * - * @param rows the num of rows. - * @param columns the num of columns. - * @param defaultValue the default value. - */ - public DenseDoubleMatrix(int rows, int columns, double defaultValue) { - this.numRows = rows; - this.numColumns = columns; - this.matrix = new double[rows][columns]; - - for (int i = 0; i < numRows; i++) { - Arrays.fill(matrix[i], defaultValue); - } - } - - /** - * Creates a new empty matrix from the rows and columns filled with the given - * random values. - * - * @param rows the num of rows. - * @param columns the num of columns. - * @param rand the random instance to use. - */ - public DenseDoubleMatrix(int rows, int columns, Random rand) { - this.numRows = rows; - this.numColumns = columns; - this.matrix = new double[rows][columns]; - - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numColumns; j++) { - matrix[i][j] = rand.nextDouble(); - } - } - } - - /** - * Simple copy constructor, but does only bend the reference to this instance. - * - * @param otherMatrix the other matrix. - */ - public DenseDoubleMatrix(double[][] otherMatrix) { - this.matrix = otherMatrix; - this.numRows = otherMatrix.length; - if (matrix.length > 0) - this.numColumns = matrix[0].length; - else - this.numColumns = numRows; - } - - /** - * Generates a matrix out of an vector array. it treats the array entries as - * rows and the vector itself contains the values of the columns. 
- * - * @param vectorArray the array of vectors. - */ - public DenseDoubleMatrix(DoubleVector[] vectorArray) { - this.matrix = new double[vectorArray.length][]; - this.numRows = vectorArray.length; - - for (int i = 0; i < vectorArray.length; i++) { - this.setRowVector(i, vectorArray[i]); - } - - if (matrix.length > 0) - this.numColumns = matrix[0].length; - else - this.numColumns = numRows; - } - - /** - * Sets the first column of this matrix to the given vector. - * - * @param first the new first column of the given vector - */ - public DenseDoubleMatrix(DenseDoubleVector first) { - this(first.getLength(), 1); - setColumn(0, first.toArray()); - } - - /** - * Copies the given double array v into the first row of this matrix, and - * creates this with the number of given rows and columns. - * - * @param v the values to put into the first row. - * @param rows the number of rows. - * @param columns the number of columns. - */ - public DenseDoubleMatrix(double[] v, int rows, int columns) { - this.matrix = new double[rows][columns]; - - for (int i = 0; i < rows; i++) { - System.arraycopy(v, i * columns, this.matrix[i], 0, columns); - } - - int index = 0; - for (int col = 0; col < columns; col++) { - for (int row = 0; row < rows; row++) { - matrix[row][col] = v[index++]; - } - } - - this.numRows = rows; - this.numColumns = columns; - } - - /** - * Creates a new matrix with the given vector into the first column and the - * other matrix to the other columns. - * - * @param first the new first column. - * @param otherMatrix the other matrix to set on from the second column. 
- */ - public DenseDoubleMatrix(DenseDoubleVector first, DoubleMatrix otherMatrix) { - this(otherMatrix.getRowCount(), otherMatrix.getColumnCount() + 1); - setColumn(0, first.toArray()); - for (int col = 1; col < otherMatrix.getColumnCount() + 1; col++) - setColumnVector(col, otherMatrix.getColumnVector(col - 1)); - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#get(int, int) - */ - @Override - public final double get(int row, int col) { - return this.matrix[row][col]; - } - - /** - * Gets a whole column of the matrix as a double array. - */ - public final double[] getColumn(int col) { - final double[] column = new double[numRows]; - for (int r = 0; r < numRows; r++) { - column[r] = matrix[r][col]; - } - return column; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#getColumnCount() - */ - @Override - public final int getColumnCount() { - return numColumns; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#getColumnVector(int) - */ - @Override - public final DoubleVector getColumnVector(int col) { - return new DenseDoubleVector(getColumn(col)); - } - - /** - * Get the matrix as 2-dimensional double array (first dimension is the row, - * second the column) to faster access the values. - */ - public final double[][] getValues() { - return matrix; - } - - /** - * Get a single row of the matrix as a double array. 
- */ - public final double[] getRow(int row) { - return matrix[row]; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#getRowCount() - */ - @Override - public final int getRowCount() { - return numRows; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#getRowVector(int) - */ - @Override - public final DoubleVector getRowVector(int row) { - return new DenseDoubleVector(getRow(row)); - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#set(int, int, double) - */ - @Override - public final void set(int row, int col, double value) { - this.matrix[row][col] = value; - } - - /** - * Sets the row to a given double array. This does not copy, rather than just - * bends the references. - */ - public final void setRow(int row, double[] value) { - this.matrix[row] = value; - } - - /** - * Sets the column to a given double array. This does not copy, rather than - * just bends the references. - */ - public final void setColumn(int col, double[] values) { - for (int i = 0; i < getRowCount(); i++) { - this.matrix[i][col] = values[i]; - } - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#setColumnVector(int, - * de.jungblut.math.DoubleVector) - */ - @Override - public void setColumnVector(int col, DoubleVector column) { - this.setColumn(col, column.toArray()); - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#setRowVector(int, - * de.jungblut.math.DoubleVector) - */ - @Override - public void setRowVector(int rowIndex, DoubleVector row) { - this.setRow(rowIndex, row.toArray()); - } - - /** - * Returns the size of the matrix as string (ROWSxCOLUMNS). - */ - public String sizeToString() { - return numRows + "x" + numColumns; - } - - /** - * Splits the last column from this matrix. Usually used to get a prediction - * column from some machine learning problem. - * - * @return a tuple of a new sliced matrix and a vector which was the last - * column. 
- */ - public final Tuple splitLastColumn() { - DenseDoubleMatrix m = new DenseDoubleMatrix(getRowCount(), - getColumnCount() - 1); - for (int i = 0; i < getRowCount(); i++) { - for (int j = 0; j < getColumnCount() - 1; j++) { - m.set(i, j, get(i, j)); - } - } - DenseDoubleVector v = new DenseDoubleVector(getColumn(getColumnCount() - 1)); - return new Tuple(m, v); - } - - /** - * Creates two matrices out of this by the given percentage. It uses a random - * function to determine which rows should belong to the matrix including the - * given percentage amount of rows. - * - * @param percentage A float value between 0.0f and 1.0f - * @return A tuple which includes two matrices, the first contains the - * percentage of the rows from the original matrix (rows are chosen - * randomly) and the second one contains all other rows. - */ - public final Tuple splitRandomMatrices( - float percentage) { - if (percentage < 0.0f || percentage > 1.0f) { - throw new IllegalArgumentException( - "Percentage must be between 0.0 and 1.0! 
Given " + percentage); - } - - if (percentage == 1.0f) { - return new Tuple(this, null); - } else if (percentage == 0.0f) { - return new Tuple(null, this); - } - - final Random rand = new Random(System.nanoTime()); - int firstMatrixRowsCount = Math.round(percentage * numRows); - - // we first choose needed rows number of items to pick - final HashSet lowerMatrixRowIndices = new HashSet(); - int missingRows = firstMatrixRowsCount; - while (missingRows > 0) { - final int nextIndex = rand.nextInt(numRows); - if (lowerMatrixRowIndices.add(nextIndex)) { - missingRows--; - } - } - - // make to new matrixes - final double[][] firstMatrix = new double[firstMatrixRowsCount][numColumns]; - int firstMatrixIndex = 0; - final double[][] secondMatrix = new double[numRows - firstMatrixRowsCount][numColumns]; - int secondMatrixIndex = 0; - - // then we loop over all items and put split the matrix - for (int r = 0; r < numRows; r++) { - if (lowerMatrixRowIndices.contains(r)) { - firstMatrix[firstMatrixIndex++] = matrix[r]; - } else { - secondMatrix[secondMatrixIndex++] = matrix[r]; - } - } - - return new Tuple( - new DenseDoubleMatrix(firstMatrix), new DenseDoubleMatrix(secondMatrix)); - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#multiply(double) - */ - @Override - public final DenseDoubleMatrix multiply(double scalar) { - DenseDoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns); - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numColumns; j++) { - m.set(i, j, this.matrix[i][j] * scalar); - } - } - return m; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#multiply(de.jungblut.math.DoubleMatrix) - */ - @Override - public final DoubleMatrix multiplyUnsafe(DoubleMatrix other) { - DenseDoubleMatrix matrix = new DenseDoubleMatrix(this.getRowCount(), - other.getColumnCount()); - - final int m = this.numRows; - final int n = this.numColumns; - final int p = other.getColumnCount(); - - for (int j = p; --j >= 0;) { - 
for (int i = m; --i >= 0;) { - double s = 0; - for (int k = n; --k >= 0;) { - s += get(i, k) * other.get(k, j); - } - matrix.set(i, j, s + matrix.get(i, j)); - } - } - - return matrix; - } - - /* - * (non-Javadoc) - * @see - * de.jungblut.math.DoubleMatrix#multiplyElementWise(de.jungblut.math.DoubleMatrix - * ) - */ - @Override - public final DoubleMatrix multiplyElementWiseUnsafe(DoubleMatrix other) { - DenseDoubleMatrix matrix = new DenseDoubleMatrix(this.numRows, - this.numColumns); - - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numColumns; j++) { - matrix.set(i, j, this.get(i, j) * (other.get(i, j))); - } - } - - return matrix; - } - - /* - * (non-Javadoc) - * @see - * de.jungblut.math.DoubleMatrix#multiplyVector(de.jungblut.math.DoubleVector) - */ - @Override - public final DoubleVector multiplyVectorUnsafe(DoubleVector v) { - DoubleVector vector = new DenseDoubleVector(this.getRowCount()); - for (int row = 0; row < numRows; row++) { - double sum = 0.0d; - for (int col = 0; col < numColumns; col++) { - sum += (matrix[row][col] * v.get(col)); - } - vector.set(row, sum); - } - - return vector; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#transpose() - */ - @Override - public DenseDoubleMatrix transpose() { - DenseDoubleMatrix m = new DenseDoubleMatrix(this.numColumns, this.numRows); - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numColumns; j++) { - m.set(j, i, this.matrix[i][j]); - } - } - return m; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#subtractBy(double) - */ - @Override - public DenseDoubleMatrix subtractBy(double amount) { - DenseDoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns); - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numColumns; j++) { - m.set(i, j, amount - this.matrix[i][j]); - } - } - return m; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#subtract(double) - */ - @Override - public DenseDoubleMatrix 
subtract(double amount) { - DenseDoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns); - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numColumns; j++) { - m.set(i, j, this.matrix[i][j] - amount); - } - } - return m; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#subtract(de.jungblut.math.DoubleMatrix) - */ - @Override - public DoubleMatrix subtractUnsafe(DoubleMatrix other) { - DoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns); - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numColumns; j++) { - m.set(i, j, this.matrix[i][j] - other.get(i, j)); - } - } - return m; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#subtract(de.jungblut.math.DoubleVector) - */ - @Override - public DenseDoubleMatrix subtractUnsafe(DoubleVector vec) { - DenseDoubleMatrix cop = new DenseDoubleMatrix(this.getRowCount(), - this.getColumnCount()); - for (int i = 0; i < this.getColumnCount(); i++) { - cop.setColumn(i, getColumnVector(i).subtract(vec.get(i)).toArray()); - } - return cop; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#divide(de.jungblut.math.DoubleVector) - */ - @Override - public DoubleMatrix divideUnsafe(DoubleVector vec) { - DoubleMatrix cop = new DenseDoubleMatrix(this.getRowCount(), - this.getColumnCount()); - for (int i = 0; i < this.getColumnCount(); i++) { - cop.setColumnVector(i, getColumnVector(i).divide(vec.get(i))); - } - return cop; - } - - /** - * {@inheritDoc} - */ - @Override - public DoubleMatrix divide(DoubleVector vec) { - Preconditions.checkArgument(this.getColumnCount() == vec.getDimension(), - "Dimension mismatch."); - return this.divideUnsafe(vec); - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#divide(de.jungblut.math.DoubleMatrix) - */ - @Override - public DoubleMatrix divideUnsafe(DoubleMatrix other) { - DoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns); - for (int i = 0; i < numRows; i++) { 
- for (int j = 0; j < numColumns; j++) { - m.set(i, j, this.matrix[i][j] / other.get(i, j)); - } - } - return m; - } - - @Override - public DoubleMatrix divide(DoubleMatrix other) { - Preconditions.checkArgument(this.getRowCount() == other.getRowCount() - && this.getColumnCount() == other.getColumnCount()); - return divideUnsafe(other); - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#divide(double) - */ - @Override - public DoubleMatrix divide(double scalar) { - DoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns); - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numColumns; j++) { - m.set(i, j, this.matrix[i][j] / scalar); - } - } - return m; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#add(de.jungblut.math.DoubleMatrix) - */ - @Override - public DoubleMatrix add(DoubleMatrix other) { - DoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns); - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numColumns; j++) { - m.set(i, j, this.matrix[i][j] + other.get(i, j)); - } - } - return m; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#pow(int) - */ - @Override - public DoubleMatrix pow(int x) { - DoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns); - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numColumns; j++) { - m.set(i, j, Math.pow(matrix[i][j], x)); - } - } - return m; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#max(int) - */ - @Override - public double max(int column) { - double max = Double.MIN_VALUE; - for (int i = 0; i < getRowCount(); i++) { - double d = matrix[i][column]; - if (d > max) { - max = d; - } - } - return max; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#min(int) - */ - @Override - public double min(int column) { - double min = Double.MAX_VALUE; - for (int i = 0; i < getRowCount(); i++) { - double d = matrix[i][column]; - if (d < min) { - min = d; - } - } - 
return min; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#slice(int, int) - */ - @Override - public DoubleMatrix slice(int rows, int cols) { - return slice(0, rows, 0, cols); - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#slice(int, int, int, int) - */ - @Override - public DoubleMatrix slice(int rowOffset, int rowMax, int colOffset, int colMax) { - DenseDoubleMatrix m = new DenseDoubleMatrix(rowMax - rowOffset, colMax - - colOffset); - for (int row = rowOffset; row < rowMax; row++) { - for (int col = colOffset; col < colMax; col++) { - m.set(row - rowOffset, col - colOffset, this.get(row, col)); - } - } - return m; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#isSparse() - */ - @Override - public boolean isSparse() { - return false; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#sum() - */ - @Override - public double sum() { - double x = 0.0d; - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numColumns; j++) { - x += Math.abs(matrix[i][j]); - } - } - return x; - } - - /* - * (non-Javadoc) - * @see de.jungblut.math.DoubleMatrix#columnIndices() - */ - @Override - public int[] columnIndices() { - int[] x = new int[getColumnCount()]; - for (int i = 0; i < getColumnCount(); i++) - x[i] = i; - return x; - } - - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + Arrays.hashCode(matrix); - result = prime * result + numColumns; - result = prime * result + numRows; - return result; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) - return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) - return false; - DenseDoubleMatrix other = (DenseDoubleMatrix) obj; - if (!Arrays.deepEquals(matrix, other.matrix)) - return false; - if (numColumns != other.numColumns) - return false; - return numRows == other.numRows; - } - - @Override - public String toString() { - if (numRows < 10) { 
- StringBuilder sb = new StringBuilder(); - for (int i = 0; i < numRows; i++) { - sb.append(Arrays.toString(matrix[i])); - sb.append('\n'); - } - return sb.toString(); - } else { - return numRows + "x" + numColumns; - } - } - - /** - * Gets the eye matrix (ones on the main diagonal) with a given dimension. - */ - public static DenseDoubleMatrix eye(int dimension) { - DenseDoubleMatrix m = new DenseDoubleMatrix(dimension, dimension); - - for (int i = 0; i < dimension; i++) { - m.set(i, i, 1); - } - - return m; - } - - /** - * Deep copies the given matrix into a new returned one. - */ - public static DenseDoubleMatrix copy(DenseDoubleMatrix matrix) { - final double[][] src = matrix.getValues(); - final double[][] dest = new double[matrix.getRowCount()][matrix - .getColumnCount()]; - - for (int i = 0; i < dest.length; i++) - System.arraycopy(src[i], 0, dest[i], 0, src[i].length); - - return new DenseDoubleMatrix(dest); - } - - /** - * Some strange function I found in octave but I don't know what it was named. - * It does however multiply the elements from the transposed vector and the - * normal vector and sets it into the according indices of a new constructed - * matrix. - */ - public static DenseDoubleMatrix multiplyTransposedVectors( - DoubleVector transposed, DoubleVector normal) { - DenseDoubleMatrix m = new DenseDoubleMatrix(transposed.getLength(), - normal.getLength()); - for (int row = 0; row < transposed.getLength(); row++) { - for (int col = 0; col < normal.getLength(); col++) { - m.set(row, col, transposed.get(row) * normal.get(col)); - } - } - - return m; - } - - /** - * Just a absolute error function. 
- */ - public static double error(DenseDoubleMatrix a, DenseDoubleMatrix b) { - return a.subtractUnsafe(b).sum(); - } - - @Override - /** - * {@inheritDoc} - */ - public DoubleMatrix applyToElements(DoubleFunction fun) { - for (int r = 0; r < this.numRows; ++r) { - for (int c = 0; c < this.numColumns; ++c) { - this.set(r, c, fun.apply(this.get(r, c))); - } - } - return this; - } - - @Override - /** - * {@inheritDoc} - */ - public DoubleMatrix applyToElements(DoubleMatrix other, - DoubleDoubleFunction fun) { - Preconditions - .checkArgument(this.numRows == other.getRowCount() - && this.numColumns == other.getColumnCount(), - "Cannot apply double double function to matrices with different sizes."); - - for (int r = 0; r < this.numRows; ++r) { - for (int c = 0; c < this.numColumns; ++c) { - this.set(r, c, fun.apply(this.get(r, c), other.get(r, c))); - } - } - - return this; - } - - /* - * (non-Javadoc) - * @see - * org.apache.hama.ml.math.DoubleMatrix#safeMultiply(org.apache.hama.ml.math - * .DoubleMatrix) - */ - @Override - public DoubleMatrix multiply(DoubleMatrix other) { - Preconditions - .checkArgument( - this.numColumns == other.getRowCount(), - String - .format( - "Matrix with size [%d, %d] cannot multiple matrix with size [%d, %d]", - this.numRows, this.numColumns, other.getRowCount(), - other.getColumnCount())); - - return this.multiplyUnsafe(other); - } - - /* - * (non-Javadoc) - * @see - * org.apache.hama.ml.math.DoubleMatrix#safeMultiplyElementWise(org.apache - * .hama.ml.math.DoubleMatrix) - */ - @Override - public DoubleMatrix multiplyElementWise(DoubleMatrix other) { - Preconditions.checkArgument(this.numRows == other.getRowCount() - && this.numColumns == other.getColumnCount(), - "Matrices with different dimensions cannot be multiplied elementwise."); - return this.multiplyElementWiseUnsafe(other); - } - - /* - * (non-Javadoc) - * @see - * org.apache.hama.ml.math.DoubleMatrix#safeMultiplyVector(org.apache.hama - * .ml.math.DoubleVector) - */ - 
@Override - public DoubleVector multiplyVector(DoubleVector v) { - Preconditions.checkArgument(this.numColumns == v.getDimension(), - "Dimension mismatch."); - return this.multiplyVectorUnsafe(v); - } - - /* - * (non-Javadoc) - * @see org.apache.hama.ml.math.DoubleMatrix#subtract(org.apache.hama.ml.math. - * DoubleMatrix) - */ - @Override - public DoubleMatrix subtract(DoubleMatrix other) { - Preconditions.checkArgument(this.numRows == other.getRowCount() - && this.numColumns == other.getColumnCount(), "Dimension mismatch."); - return subtractUnsafe(other); - } - - /* - * (non-Javadoc) - * @see org.apache.hama.ml.math.DoubleMatrix#subtract(org.apache.hama.ml.math. - * DoubleVector) - */ - @Override - public DoubleMatrix subtract(DoubleVector vec) { - Preconditions.checkArgument(this.numColumns == vec.getDimension(), - "Dimension mismatch."); - return null; - } - -} Index: ml/src/main/java/org/apache/hama/ml/math/DoubleFunction.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleFunction.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleFunction.java (working copy) @@ -1,43 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.math; - -/** - * A double double function takes two arguments. A vector or matrix can apply - * the double function to each element. - * - */ -public abstract class DoubleFunction extends Function { - - /** - * Apply the function to element. - * - * @param elem The element that the function apply to. - * @return The result after applying the function. - */ - public abstract double apply(double value); - - /** - * Apply the gradient of the function. - * - * @param elem - * @return - */ - public abstract double applyDerivative(double value); - -} Index: ml/src/main/java/org/apache/hama/ml/math/Function.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/Function.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/math/Function.java (working copy) @@ -1,33 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.math; - -/** - * A generic function. - * - */ -public abstract class Function { - /** - * Get the name of the function. - * - * @return The name of the function. 
- */ - final public String getFunctionName() { - return this.getClass().getSimpleName(); - } -} Index: ml/src/main/java/org/apache/hama/ml/math/Tuple.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/Tuple.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/math/Tuple.java (working copy) @@ -1,85 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.math; - -/** - * Tuple class to hold two generic attributes. This class implements hashcode, - * equals and comparable via the first element. - */ -public final class Tuple implements - Comparable> { - - private final FIRST first; - private final SECOND second; - - public Tuple(FIRST first, SECOND second) { - super(); - this.first = first; - this.second = second; - } - - public final FIRST getFirst() { - return first; - } - - public final SECOND getSecond() { - return second; - } - - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((first == null) ? 
0 : first.hashCode()); - return result; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) - return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) - return false; - @SuppressWarnings("rawtypes") - Tuple other = (Tuple) obj; - if (first == null) { - if (other.first != null) - return false; - } else if (!first.equals(other.first)) - return false; - return true; - } - - @SuppressWarnings("unchecked") - @Override - public int compareTo(Tuple o) { - if (o.getFirst() instanceof Comparable && getFirst() instanceof Comparable) { - return ((Comparable) getFirst()).compareTo(o.getFirst()); - } else { - return 0; - } - } - - @Override - public String toString() { - return "Tuple [first=" + first + ", second=" + second + "]"; - } - -} Index: ml/src/main/java/org/apache/hama/ml/math/Tanh.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/Tanh.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/math/Tanh.java (working copy) @@ -1,36 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.math; - -/** - * Tanh function. 
- * - */ -public class Tanh extends DoubleFunction { - - @Override - public double apply(double value) { - return Math.tanh(value); - } - - @Override - public double applyDerivative(double value) { - return 1 - value * value; - } - -} Index: ml/src/main/java/org/apache/hama/ml/math/Sigmoid.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/Sigmoid.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/math/Sigmoid.java (working copy) @@ -1,39 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.math; - -/** - * The Sigmoid function - * - *
- * f(x) = 1 / (1 + e^{-x})
- * 
- */ -public class Sigmoid extends DoubleFunction { - - @Override - public double apply(double value) { - return 1.0 / (1 + Math.exp(-value)); - } - - @Override - public double applyDerivative(double value) { - return value * (1 - value); - } - -} Index: ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleVector.java (working copy) @@ -1,388 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.math; - -import java.util.Iterator; - -/** - * Vector with doubles. Some of the operations are mutable, unlike the apply and - * math functions, they return a fresh instance every time. - * - */ -public interface DoubleVector { - - /** - * Retrieves the value at given index. - * - * @param index the index. - * @return a double value at the index. - */ - public double get(int index); - - /** - * Get the length of a vector, for sparse instance it is the actual length. - * (not the dimension!) Always a constant time operation. - * - * @return the length of the vector. 
- */ - public int getLength(); - - /** - * Get the dimension of a vector, for dense instance it is the same like the - * length, for sparse instances it is usually not the same. Always a constant - * time operation. - * - * @return the dimension of the vector. - */ - public int getDimension(); - - /** - * Set a value at the given index. - * - * @param index the index of the vector to set. - * @param value the value at the index of the vector to set. - */ - public void set(int index, double value); - - /** - * Apply a given {@link DoubleVectorFunction} to this vector and return a new - * one. - * - * @param func the function to apply. - * @return a new vector with the applied function. - */ - @Deprecated - public DoubleVector apply(DoubleVectorFunction func); - - /** - * Apply a given {@link DoubleDoubleVectorFunction} to this vector and the - * other given vector. - * - * @param other the other vector. - * @param func the function to apply on this and the other vector. - * @return a new vector with the result of the function of the two vectors. - */ - @Deprecated - public DoubleVector apply(DoubleVector other, DoubleDoubleVectorFunction func); - - /** - * Apply a given {@link DoubleVectorFunction} to this vector and return a new - * one. - * - * @param func the function to apply. - * @return a new vector with the applied function. - */ - public DoubleVector applyToElements(DoubleFunction func); - - /** - * Apply a given {@link DoubleDoubleVectorFunction} to this vector and the - * other given vector. - * - * @param other the other vector. - * @param func the function to apply on this and the other vector. - * @return a new vector with the result of the function of the two vectors. - */ - public DoubleVector applyToElements(DoubleVector other, - DoubleDoubleFunction func); - - /** - * Adds the given {@link DoubleVector} to this vector. - * - * @param vector the other vector. - * @return a new vector with the sum of both vectors at each element index. 
- */ - public DoubleVector addUnsafe(DoubleVector vector); - - /** - * Validates the input and adds the given {@link DoubleVector} to this vector. - * - * @param vector the other vector. - * @return a new vector with the sum of both vectors at each element index. - */ - public DoubleVector add(DoubleVector vector); - - /** - * Adds the given scalar to this vector. - * - * @param scalar the scalar. - * @return a new vector with the result at each element index. - */ - public DoubleVector add(double scalar); - - /** - * Subtracts this vector by the given {@link DoubleVector}. - * - * @param vector the other vector. - * @return a new vector with the difference of both vectors. - */ - public DoubleVector subtractUnsafe(DoubleVector vector); - - /** - * Validates the input and subtracts this vector by the given - * {@link DoubleVector}. - * - * @param vector the other vector. - * @return a new vector with the difference of both vectors. - */ - public DoubleVector subtract(DoubleVector vector); - - /** - * Subtracts the given scalar to this vector. (vector - scalar). - * - * @param scalar the scalar. - * @return a new vector with the result at each element index. - */ - public DoubleVector subtract(double scalar); - - /** - * Subtracts the given scalar from this vector. (scalar - vector). - * - * @param scalar the scalar. - * @return a new vector with the result at each element index. - */ - public DoubleVector subtractFrom(double scalar); - - /** - * Multiplies the given scalar to this vector. - * - * @param scalar the scalar. - * @return a new vector with the result of the operation. - */ - public DoubleVector multiply(double scalar); - - /** - * Multiplies the given {@link DoubleVector} with this vector. - * - * @param vector the other vector. - * @return a new vector with the result of the operation. - */ - public DoubleVector multiplyUnsafe(DoubleVector vector); - - /** - * Validates the input and multiplies the given {@link DoubleVector} with this - * vector. 
- * - * @param vector the other vector. - * @return a new vector with the result of the operation. - */ - public DoubleVector multiply(DoubleVector vector); - - /** - * Validates the input and multiplies the given {@link DoubleMatrix} with this - * vector. - * - * @param matrix - * @return - */ - public DoubleVector multiply(DoubleMatrix matrix); - - /** - * Multiplies the given {@link DoubleMatrix} with this vector. - * - * @param matrix - * @return - */ - public DoubleVector multiplyUnsafe(DoubleMatrix matrix); - - /** - * Divides this vector by the given scalar. (= vector/scalar). - * - * @param scalar the given scalar. - * @return a new vector with the result of the operation. - */ - public DoubleVector divide(double scalar); - - /** - * Divides the given scalar by this vector. (= scalar/vector). - * - * @param scalar the given scalar. - * @return a new vector with the result of the operation. - */ - public DoubleVector divideFrom(double scalar); - - /** - * Powers this vector by the given amount. (=vector^x). - * - * @param x the given exponent. - * @return a new vector with the result of the operation. - */ - public DoubleVector pow(int x); - - /** - * Absolutes the vector at each element. - * - * @return a new vector that does not contain negative values anymore. - */ - public DoubleVector abs(); - - /** - * Square-roots each element. - * - * @return a new vector. - */ - public DoubleVector sqrt(); - - /** - * @return the sum of all elements in this vector. - */ - public double sum(); - - /** - * Calculates the dot product between this vector and the given vector. - * - * @param vector the given vector. - * @return the dot product as a double. - */ - public double dotUnsafe(DoubleVector vector); - - /** - * Validates the input and calculates the dot product between this vector and - * the given vector. - * - * @param vector the given vector. - * @return the dot product as a double. 
- */ - public double dot(DoubleVector vector); - - /** - * Validates the input and slices this vector from index 0 to the given - * length. - * - * @param length must be > 0 and smaller than the dimension of the vector. - * @return a new vector that is only length long. - */ - public DoubleVector slice(int length); - - /** - * Slices this vector from index 0 to the given length. - * - * @param length must be > 0 and smaller than the dimension of the vector. - * @return a new vector that is only length long. - */ - public DoubleVector sliceUnsafe(int length); - - /** - * Validates the input and then slices this vector from start to end, both are - * INCLUSIVE. For example vec = [0, 1, 2, 3, 4, 5], vec.slice(2, 5) = [2, 3, - * 4, 5]. - * - * @param offset must be > 0 and smaller than the dimension of the vector - * @param length must be > 0 and smaller than the dimension of the vector. - * This must be greater than the offset. - * @return a new vector that is only (length) long. - */ - public DoubleVector slice(int start, int end); - - /** - * Slices this vector from start to end, both are INCLUSIVE. For example vec = - * [0, 1, 2, 3, 4, 5], vec.slice(2, 5) = [2, 3, 4, 5]. - * - * @param offset must be > 0 and smaller than the dimension of the vector - * @param length must be > 0 and smaller than the dimension of the vector. - * This must be greater than the offset. - * @return a new vector that is only (length) long. - */ - public DoubleVector sliceUnsafe(int start, int end); - - /** - * @return the maximum element value in this vector. - */ - public double max(); - - /** - * @return the minimum element value in this vector. - */ - public double min(); - - /** - * @return an array representation of this vector. - */ - public double[] toArray(); - - /** - * @return a fresh new copy of this vector, copies all elements to a new - * vector. (Does not reuse references or stuff). 
- */ - public DoubleVector deepCopy(); - - /** - * @return an iterator that only iterates over non zero elements. - */ - public Iterator iterateNonZero(); - - /** - * @return an iterator that iterates over all elements. - */ - public Iterator iterate(); - - /** - * @return true if this instance is a sparse vector. Smarter and faster than - * instanceof. - */ - public boolean isSparse(); - - /** - * @return true if this instance is a named vector.Smarter and faster than - * instanceof. - */ - public boolean isNamed(); - - /** - * @return If this vector is a named instance, this will return its name. Or - * null if this is not a named instance. - * - */ - public String getName(); - - /** - * Class for iteration of elements, consists of an index and a value at this - * index. Can be reused for performance purposes. - */ - public static final class DoubleVectorElement { - - private int index; - private double value; - - public DoubleVectorElement() { - super(); - } - - public DoubleVectorElement(int index, double value) { - super(); - this.index = index; - this.value = value; - } - - public final int getIndex() { - return index; - } - - public final double getValue() { - return value; - } - - public final void setIndex(int in) { - this.index = in; - } - - public final void setValue(double in) { - this.value = in; - } - } - -} Index: ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleVectorFunction.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleVectorFunction.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/math/DoubleDoubleVectorFunction.java (working copy) @@ -1,35 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.math; - -/** - * A function that can be applied to two double vectors via {@link DoubleVector} - * #apply({@link DoubleVector} v, {@link DoubleDoubleVectorFunction} f); - * - * This class will be replaced by {@link DoubleDoubleFunction} - */ -@Deprecated -public interface DoubleDoubleVectorFunction { - - /** - * Calculates the result of the left and right value of two vectors at a given - * index. - */ - public double calculate(int index, double left, double right); - -} Index: ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java (working copy) @@ -1,133 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.writable; - -import java.io.DataInput; -import java.io.DataOutput; -import java.io.IOException; - -import org.apache.hadoop.io.WritableComparable; -import org.apache.hama.ml.math.DenseDoubleVector; -import org.apache.hama.ml.math.DoubleVector; - -/** - * Writable for dense vectors. - */ -public final class VectorWritable implements WritableComparable { - - private DoubleVector vector; - - public VectorWritable() { - super(); - } - - public VectorWritable(VectorWritable v) { - this.vector = v.getVector(); - } - - public VectorWritable(DoubleVector v) { - this.vector = v; - } - - @Override - public final void write(DataOutput out) throws IOException { - writeVector(this.vector, out); - } - - @Override - public final void readFields(DataInput in) throws IOException { - this.vector = readVector(in); - } - - @Override - public final int compareTo(VectorWritable o) { - return compareVector(this, o); - } - - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((vector == null) ? 
0 : vector.hashCode()); - return result; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) - return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) - return false; - VectorWritable other = (VectorWritable) obj; - if (vector == null) { - if (other.vector != null) - return false; - } else if (!vector.equals(other.vector)) - return false; - return true; - } - - /** - * @return the embedded vector - */ - public DoubleVector getVector() { - return vector; - } - - @Override - public String toString() { - return vector.toString(); - } - - public static void writeVector(DoubleVector vector, DataOutput out) - throws IOException { - out.writeInt(vector.getLength()); - for (int i = 0; i < vector.getDimension(); i++) { - out.writeDouble(vector.get(i)); - } - } - - public static DoubleVector readVector(DataInput in) throws IOException { - int length = in.readInt(); - DoubleVector vector; - vector = new DenseDoubleVector(length); - for (int i = 0; i < length; i++) { - vector.set(i, in.readDouble()); - } - return vector; - } - - public static int compareVector(VectorWritable a, VectorWritable o) { - return compareVector(a.getVector(), o.getVector()); - } - - public static int compareVector(DoubleVector a, DoubleVector o) { - DoubleVector subtract = a.subtractUnsafe(o); - return (int) subtract.sum(); - } - - public static VectorWritable wrap(DoubleVector a) { - return new VectorWritable(a); - } - - public void set(DoubleVector vector) { - this.vector = vector; - } -} Index: ml/src/main/java/org/apache/hama/ml/writable/MatrixWritable.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/writable/MatrixWritable.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/writable/MatrixWritable.java (working copy) @@ -1,73 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.ml.writable; - -import java.io.DataInput; -import java.io.DataOutput; -import java.io.IOException; - -import org.apache.hadoop.io.Writable; -import org.apache.hama.ml.math.DenseDoubleMatrix; -import org.apache.hama.ml.math.DoubleMatrix; - -/** - * Majorly designed for dense matrices, can be extended for sparse ones as well. 
- */ -public final class MatrixWritable implements Writable { - - private DoubleMatrix mat; - - public MatrixWritable() { - } - - public MatrixWritable(DoubleMatrix mat) { - this.mat = mat; - - } - - @Override - public void readFields(DataInput in) throws IOException { - mat = read(in); - } - - @Override - public void write(DataOutput out) throws IOException { - write(mat, out); - } - - public static void write(DoubleMatrix mat, DataOutput out) throws IOException { - out.writeInt(mat.getRowCount()); - out.writeInt(mat.getColumnCount()); - for (int row = 0; row < mat.getRowCount(); row++) { - for (int col = 0; col < mat.getColumnCount(); col++) { - out.writeDouble(mat.get(row, col)); - } - } - } - - public static DoubleMatrix read(DataInput in) throws IOException { - DoubleMatrix mat = new DenseDoubleMatrix(in.readInt(), in.readInt()); - for (int row = 0; row < mat.getRowCount(); row++) { - for (int col = 0; col < mat.getColumnCount(); col++) { - mat.set(row, col, in.readDouble()); - } - } - return mat; - } - -} Index: ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java (working copy) @@ -17,7 +17,7 @@ */ package org.apache.hama.ml.distance; -import org.apache.hama.ml.math.DoubleVector; +import org.apache.hama.commons.math.DoubleVector; public final class EuclidianDistance implements DistanceMeasurer { Index: ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java (working copy) @@ -17,7 +17,7 @@ */ package org.apache.hama.ml.distance; -import org.apache.hama.ml.math.DoubleVector; +import 
org.apache.hama.commons.math.DoubleVector; public final class CosineDistance implements DistanceMeasurer { Index: ml/src/main/java/org/apache/hama/ml/distance/DistanceMeasurer.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/distance/DistanceMeasurer.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/distance/DistanceMeasurer.java (working copy) @@ -17,7 +17,7 @@ */ package org.apache.hama.ml.distance; -import org.apache.hama.ml.math.DoubleVector; +import org.apache.hama.commons.math.DoubleVector; /** * a {@link DistanceMeasurer} is responsible for calculating the distance Index: ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetwork.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetwork.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetwork.java (working copy) @@ -32,14 +32,14 @@ import org.apache.hadoop.io.WritableUtils; import org.apache.hama.HamaConfiguration; import org.apache.hama.bsp.BSPJob; -import org.apache.hama.ml.math.DenseDoubleMatrix; -import org.apache.hama.ml.math.DenseDoubleVector; -import org.apache.hama.ml.math.DoubleFunction; -import org.apache.hama.ml.math.DoubleMatrix; -import org.apache.hama.ml.math.DoubleVector; -import org.apache.hama.ml.math.FunctionFactory; -import org.apache.hama.ml.writable.MatrixWritable; -import org.apache.hama.ml.writable.VectorWritable; +import org.apache.hama.commons.io.writable.MatrixWritable; +import org.apache.hama.commons.io.writable.VectorWritable; +import org.apache.hama.commons.math.DenseDoubleMatrix; +import org.apache.hama.commons.math.DenseDoubleVector; +import org.apache.hama.commons.math.DoubleFunction; +import org.apache.hama.commons.math.DoubleMatrix; +import org.apache.hama.commons.math.DoubleVector; +import org.apache.hama.commons.math.FunctionFactory; import org.mortbay.log.Log; import 
com.google.common.base.Preconditions; Index: ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetworkMessage.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetworkMessage.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetworkMessage.java (working copy) @@ -22,9 +22,9 @@ import java.io.IOException; import org.apache.hadoop.io.Writable; -import org.apache.hama.ml.math.DenseDoubleMatrix; -import org.apache.hama.ml.math.DoubleMatrix; -import org.apache.hama.ml.writable.MatrixWritable; +import org.apache.hama.commons.io.writable.MatrixWritable; +import org.apache.hama.commons.math.DenseDoubleMatrix; +import org.apache.hama.commons.math.DoubleMatrix; /** * NeuralNetworkMessage transmits the messages between peers during the training Index: ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetworkTrainer.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetworkTrainer.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/ann/SmallLayeredNeuralNetworkTrainer.java (working copy) @@ -25,10 +25,10 @@ import org.apache.hama.bsp.BSP; import org.apache.hama.bsp.BSPPeer; import org.apache.hama.bsp.sync.SyncException; -import org.apache.hama.ml.math.DenseDoubleMatrix; -import org.apache.hama.ml.math.DoubleMatrix; -import org.apache.hama.ml.math.DoubleVector; -import org.apache.hama.ml.writable.VectorWritable; +import org.apache.hama.commons.io.writable.VectorWritable; +import org.apache.hama.commons.math.DenseDoubleMatrix; +import org.apache.hama.commons.math.DoubleMatrix; +import org.apache.hama.commons.math.DoubleVector; import org.mortbay.log.Log; /** Index: ml/src/main/java/org/apache/hama/ml/ann/NeuralNetworkTrainer.java =================================================================== --- 
ml/src/main/java/org/apache/hama/ml/ann/NeuralNetworkTrainer.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/ann/NeuralNetworkTrainer.java (working copy) @@ -27,8 +27,8 @@ import org.apache.hama.bsp.BSP; import org.apache.hama.bsp.BSPPeer; import org.apache.hama.bsp.sync.SyncException; +import org.apache.hama.commons.io.writable.VectorWritable; import org.apache.hama.ml.perception.MLPMessage; -import org.apache.hama.ml.writable.VectorWritable; /** * The trainer that is used to train the {@link SmallLayeredNeuralNetwork} with Index: ml/src/main/java/org/apache/hama/ml/ann/AutoEncoder.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/ann/AutoEncoder.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/ann/AutoEncoder.java (working copy) @@ -21,12 +21,12 @@ import java.util.Map; import org.apache.hadoop.fs.Path; +import org.apache.hama.commons.math.DenseDoubleVector; +import org.apache.hama.commons.math.DoubleFunction; +import org.apache.hama.commons.math.DoubleMatrix; +import org.apache.hama.commons.math.DoubleVector; +import org.apache.hama.commons.math.FunctionFactory; import org.apache.hama.ml.ann.AbstractLayeredNeuralNetwork.LearningStyle; -import org.apache.hama.ml.math.DenseDoubleVector; -import org.apache.hama.ml.math.DoubleFunction; -import org.apache.hama.ml.math.DoubleMatrix; -import org.apache.hama.ml.math.DoubleVector; -import org.apache.hama.ml.math.FunctionFactory; import com.google.common.base.Preconditions; Index: ml/src/main/java/org/apache/hama/ml/ann/AbstractLayeredNeuralNetwork.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/ann/AbstractLayeredNeuralNetwork.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/ann/AbstractLayeredNeuralNetwork.java (working copy) @@ -24,11 +24,11 @@ import java.util.List; import org.apache.hadoop.io.WritableUtils; -import 
org.apache.hama.ml.math.DoubleDoubleFunction; -import org.apache.hama.ml.math.DoubleFunction; -import org.apache.hama.ml.math.DoubleMatrix; -import org.apache.hama.ml.math.DoubleVector; -import org.apache.hama.ml.math.FunctionFactory; +import org.apache.hama.commons.math.DoubleDoubleFunction; +import org.apache.hama.commons.math.DoubleFunction; +import org.apache.hama.commons.math.DoubleMatrix; +import org.apache.hama.commons.math.DoubleVector; +import org.apache.hama.commons.math.FunctionFactory; import com.google.common.base.Preconditions; Index: ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java (working copy) @@ -40,11 +40,11 @@ import org.apache.hama.bsp.BSPJob; import org.apache.hama.bsp.BSPPeer; import org.apache.hama.bsp.sync.SyncException; +import org.apache.hama.commons.io.writable.VectorWritable; +import org.apache.hama.commons.math.DenseDoubleVector; +import org.apache.hama.commons.math.DoubleVector; import org.apache.hama.ml.distance.DistanceMeasurer; import org.apache.hama.ml.distance.EuclidianDistance; -import org.apache.hama.ml.math.DenseDoubleVector; -import org.apache.hama.ml.math.DoubleVector; -import org.apache.hama.ml.writable.VectorWritable; import org.apache.hama.util.ReflectionUtils; import com.google.common.base.Preconditions; Index: ml/src/main/java/org/apache/hama/ml/kmeans/CenterMessage.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/kmeans/CenterMessage.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/kmeans/CenterMessage.java (working copy) @@ -22,8 +22,8 @@ import java.io.IOException; import org.apache.hadoop.io.Writable; -import org.apache.hama.ml.math.DoubleVector; -import org.apache.hama.ml.writable.VectorWritable; +import 
org.apache.hama.commons.io.writable.VectorWritable; +import org.apache.hama.commons.math.DoubleVector; public final class CenterMessage implements Writable { Index: ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java (working copy) @@ -19,7 +19,7 @@ import java.math.BigDecimal; -import org.apache.hama.ml.math.DoubleVector; +import org.apache.hama.commons.math.DoubleVector; /** * A cost model for gradient descent based regression Index: ml/src/main/java/org/apache/hama/ml/regression/HypothesisFunction.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/regression/HypothesisFunction.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/regression/HypothesisFunction.java (working copy) @@ -19,7 +19,7 @@ import java.math.BigDecimal; -import org.apache.hama.ml.math.DoubleVector; +import org.apache.hama.commons.math.DoubleVector; /** * The mathematical model chosen for a specific learning problem Index: ml/src/main/java/org/apache/hama/ml/regression/VectorDoubleFileInputFormat.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/regression/VectorDoubleFileInputFormat.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/regression/VectorDoubleFileInputFormat.java (working copy) @@ -17,6 +17,9 @@ */ package org.apache.hama.ml.regression; +import java.io.IOException; +import java.io.InputStream; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; @@ -27,14 +30,15 @@ import org.apache.hadoop.io.Text; import org.apache.hadoop.io.compress.CompressionCodec; import org.apache.hadoop.io.compress.CompressionCodecFactory; 
-import org.apache.hama.bsp.*; -import org.apache.hama.ml.math.DenseDoubleVector; -import org.apache.hama.ml.math.DoubleVector; -import org.apache.hama.ml.writable.VectorWritable; +import org.apache.hama.bsp.BSPJob; +import org.apache.hama.bsp.FileInputFormat; +import org.apache.hama.bsp.FileSplit; +import org.apache.hama.bsp.InputSplit; +import org.apache.hama.bsp.RecordReader; +import org.apache.hama.commons.io.writable.VectorWritable; +import org.apache.hama.commons.math.DenseDoubleVector; +import org.apache.hama.commons.math.DoubleVector; -import java.io.IOException; -import java.io.InputStream; - /** * A {@link FileInputFormat} for files containing one vector and one double per * line Index: ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java (working copy) @@ -25,10 +25,10 @@ import org.apache.hama.bsp.BSP; import org.apache.hama.bsp.BSPPeer; import org.apache.hama.bsp.sync.SyncException; -import org.apache.hama.ml.math.DenseDoubleVector; -import org.apache.hama.ml.math.DoubleVector; -import org.apache.hama.ml.writable.VectorWritable; -import org.apache.hama.util.KeyValuePair; +import org.apache.hama.commons.io.writable.VectorWritable; +import org.apache.hama.commons.math.DenseDoubleVector; +import org.apache.hama.commons.math.DoubleVector; +import org.apache.hama.commons.util.KeyValuePair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; Index: ml/src/main/java/org/apache/hama/ml/regression/LinearRegression.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/regression/LinearRegression.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/regression/LinearRegression.java (working copy) @@ -22,10 +22,10 @@ import java.util.Map; 
import org.apache.hadoop.fs.Path; +import org.apache.hama.commons.math.DoubleMatrix; +import org.apache.hama.commons.math.DoubleVector; +import org.apache.hama.commons.math.FunctionFactory; import org.apache.hama.ml.ann.SmallLayeredNeuralNetwork; -import org.apache.hama.ml.math.DoubleMatrix; -import org.apache.hama.ml.math.DoubleVector; -import org.apache.hama.ml.math.FunctionFactory; /** * Linear regression model. It can be used for numeric regression or prediction. Index: ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java (working copy) @@ -19,7 +19,7 @@ import java.math.BigDecimal; -import org.apache.hama.ml.math.DoubleVector; +import org.apache.hama.commons.math.DoubleVector; /** * A {@link RegressionModel} for linear regression Index: ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java (working copy) @@ -19,7 +19,7 @@ import java.math.BigDecimal; -import org.apache.hama.ml.math.DoubleVector; +import org.apache.hama.commons.math.DoubleVector; /** * An optimization (minimization) problem's cost function Index: ml/src/main/java/org/apache/hama/ml/regression/LogisticRegression.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/regression/LogisticRegression.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/regression/LogisticRegression.java (working copy) @@ -22,9 +22,9 @@ import java.util.Map; import org.apache.hadoop.fs.Path; +import org.apache.hama.commons.math.DoubleVector; +import 
org.apache.hama.commons.math.FunctionFactory; import org.apache.hama.ml.ann.SmallLayeredNeuralNetwork; -import org.apache.hama.ml.math.DoubleVector; -import org.apache.hama.ml.math.FunctionFactory; /** * The logistic regression model. It can be used to conduct 2-class Index: ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java =================================================================== --- ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java (revision 1535330) +++ ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java (working copy) @@ -20,7 +20,7 @@ import java.math.BigDecimal; import java.math.MathContext; -import org.apache.hama.ml.math.DoubleVector; +import org.apache.hama.commons.math.DoubleVector; /** * A {@link RegressionModel} for logistic regression Index: ml/pom.xml =================================================================== --- ml/pom.xml (revision 1535330) +++ ml/pom.xml (working copy) @@ -33,6 +33,11 @@ org.apache.hama + hama-commons + ${project.version} + + + org.apache.hama hama-core ${project.version} @@ -42,5 +47,5 @@ ${project.version}
- + Index: commons/src/test/java/org/apache/hama/commons/math/TestDenseDoubleVector.java =================================================================== --- commons/src/test/java/org/apache/hama/commons/math/TestDenseDoubleVector.java (revision 0) +++ commons/src/test/java/org/apache/hama/commons/math/TestDenseDoubleVector.java (revision 0) @@ -0,0 +1,208 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hama.commons.math; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +import org.junit.Test; + +/** + * Testcase for {@link DenseDoubleVector} + * + */ +public class TestDenseDoubleVector { + + @Test + public void testApplyDoubleFunction() { + double[] values = new double[] {1, 2, 3, 4, 5}; + double[] result = new double[] {2, 3, 4, 5, 6}; + + DoubleVector vec1 = new DenseDoubleVector(values); + + vec1.applyToElements(new DoubleFunction() { + + @Override + public double apply(double value) { + return value + 1; + } + + @Override + public double applyDerivative(double value) { + throw new UnsupportedOperationException("Not supported."); + } + + }); + + assertArrayEquals(result, vec1.toArray(), 0.0001); + } + + @Test + public void testApplyDoubleDoubleFunction() { + double[] values1 = new double[] {1, 2, 3, 4, 5, 6}; + double[] values2 = new double[] {7, 8, 9, 10, 11, 12}; + double[] result = new double[] {8, 10, 12, 14, 16, 18}; + + DoubleVector vec1 = new DenseDoubleVector(values1); + DoubleVector vec2 = new DenseDoubleVector(values2); + + vec1.applyToElements(vec2, new DoubleDoubleFunction() { + + @Override + public double apply(double x1, double x2) { + return x1 + x2; + } + + @Override + public double applyDerivative(double x1, double x2) { + throw new UnsupportedOperationException("Not supported"); + } + + }); + + assertArrayEquals(result, vec1.toArray(), 0.0001); + + } + + @Test + public void testAddNormal() { + double[] arr1 = new double[] {1, 2, 3}; + double[] arr2 = new double[] {4, 5, 6}; + DoubleVector vec1 = new DenseDoubleVector(arr1); + DoubleVector vec2 = new DenseDoubleVector(arr2); + double[] arrExp = new double[] {5, 7, 9}; + assertArrayEquals(arrExp, vec1.add(vec2).toArray(), 0.000001); + } + + @Test(expected = IllegalArgumentException.class) + public void testAddAbnormal() { + double[] arr1 = new double[] {1, 2, 3}; + double[] arr2 = new double[] {4, 5}; + DoubleVector 
vec1 = new DenseDoubleVector(arr1); + DoubleVector vec2 = new DenseDoubleVector(arr2); + vec1.add(vec2); + } + + @Test + public void testSubtractNormal() { + double[] arr1 = new double[] {1, 2, 3}; + double[] arr2 = new double[] {4, 5, 6}; + DoubleVector vec1 = new DenseDoubleVector(arr1); + DoubleVector vec2 = new DenseDoubleVector(arr2); + double[] arrExp = new double[] {-3, -3, -3}; + assertArrayEquals(arrExp, vec1.subtract(vec2).toArray(), 0.000001); + } + + @Test(expected = IllegalArgumentException.class) + public void testSubtractAbnormal() { + double[] arr1 = new double[] {1, 2, 3}; + double[] arr2 = new double[] {4, 5}; + DoubleVector vec1 = new DenseDoubleVector(arr1); + DoubleVector vec2 = new DenseDoubleVector(arr2); + vec1.subtract(vec2); + } + + @Test + public void testMultiplyNormal() { + double[] arr1 = new double[] {1, 2, 3}; + double[] arr2 = new double[] {4, 5, 6}; + DoubleVector vec1 = new DenseDoubleVector(arr1); + DoubleVector vec2 = new DenseDoubleVector(arr2); + double[] arrExp = new double[] {4, 10, 18}; + assertArrayEquals(arrExp, vec1.multiply(vec2).toArray(), 0.000001); + } + + @Test(expected = IllegalArgumentException.class) + public void testMultiplyAbnormal() { + double[] arr1 = new double[] {1, 2, 3}; + double[] arr2 = new double[] {4, 5}; + DoubleVector vec1 = new DenseDoubleVector(arr1); + DoubleVector vec2 = new DenseDoubleVector(arr2); + vec1.multiply(vec2); + } + + @Test + public void testDotNormal() { + double[] arr1 = new double[] {1, 2, 3}; + double[] arr2 = new double[] {4, 5, 6}; + DoubleVector vec1 = new DenseDoubleVector(arr1); + DoubleVector vec2 = new DenseDoubleVector(arr2); + assertEquals(32.0, vec1.dot(vec2), 0.000001); + } + + @Test(expected = IllegalArgumentException.class) + public void testDotAbnormal() { + double[] arr1 = new double[] {1, 2, 3}; + double[] arr2 = new double[] {4, 5}; + DoubleVector vec1 = new DenseDoubleVector(arr1); + DoubleVector vec2 = new DenseDoubleVector(arr2); + vec1.dot(vec2); + } + + 
@Test + public void testSliceNormal() { + double[] arr1 = new double[] {2, 3, 4, 5, 6}; + double[] arr2 = new double[] {4, 5, 6}; + double[] arr3 = new double[] {2, 3, 4}; + DoubleVector vec1 = new DenseDoubleVector(arr1); + assertArrayEquals(arr2, vec1.slice(2, 4).toArray(), 0.000001); + DoubleVector vec2 = new DenseDoubleVector(arr1); + assertArrayEquals(arr3, vec2.slice(3).toArray(), 0.000001); + } + + @Test(expected = IllegalArgumentException.class) + public void testSliceAbnormal() { + double[] arr1 = new double[] {2, 3, 4, 5, 6}; + DoubleVector vec = new DenseDoubleVector(arr1); + vec.slice(2, 5); + } + + @Test(expected = IllegalArgumentException.class) + public void testSliceAbnormalEndTooLarge() { + double[] arr1 = new double[] {2, 3, 4, 5, 6}; + DoubleVector vec = new DenseDoubleVector(arr1); + vec.slice(2, 5); + } + + @Test(expected = IllegalArgumentException.class) + public void testSliceAbnormalStartLargerThanEnd() { + double[] arr1 = new double[] {2, 3, 4, 5, 6}; + DoubleVector vec = new DenseDoubleVector(arr1); + vec.slice(4, 3); + } + + @Test + public void testVectorMultiplyMatrix() { + DoubleVector vec = new DenseDoubleVector(new double[]{1, 2, 3}); + DoubleMatrix mat = new DenseDoubleMatrix(new double[][] { + {1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12} + }); + double[] expectedRes = new double[] {38, 44, 50, 56}; + + assertArrayEquals(expectedRes, vec.multiply(mat).toArray(), 0.000001); + } + + @Test(expected = IllegalArgumentException.class) + public void testVectorMultiplyMatrixAbnormal() { + DoubleVector vec = new DenseDoubleVector(new double[]{1, 2, 3}); + DoubleMatrix mat = new DenseDoubleMatrix(new double[][] { + {1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16} + }); + vec.multiply(mat); + } +} Index: commons/src/test/java/org/apache/hama/commons/math/TestFunctionFactory.java =================================================================== --- commons/src/test/java/org/apache/hama/commons/math/TestFunctionFactory.java 
(revision 0) +++ commons/src/test/java/org/apache/hama/commons/math/TestFunctionFactory.java (revision 0) @@ -0,0 +1,82 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.commons.math; + +import static org.junit.Assert.assertEquals; + +import java.util.Random; + +import org.junit.Test; + +/** + * Test case for {@link FunctionFactory} + * + */ +public class TestFunctionFactory { + + @Test + public void testCreateDoubleFunction() { + double input = 0.8; + + String sigmoidName = "Sigmoid"; + DoubleFunction sigmoidFunction = FunctionFactory + .createDoubleFunction(sigmoidName); + assertEquals(sigmoidName, sigmoidFunction.getFunctionName()); + + double sigmoidExpected = 0.68997448; + assertEquals(sigmoidExpected, sigmoidFunction.apply(input), 0.000001); + + + String tanhName = "Tanh"; + DoubleFunction tanhFunction = FunctionFactory.createDoubleFunction(tanhName); + assertEquals(tanhName, tanhFunction.getFunctionName()); + + double tanhExpected = 0.66403677; + assertEquals(tanhExpected, tanhFunction.apply(input), 0.00001); + + + String identityFunctionName = "IdentityFunction"; + DoubleFunction identityFunction = FunctionFactory.createDoubleFunction(identityFunctionName); + + Random rnd = new 
Random(); + double identityExpected = rnd.nextDouble(); + assertEquals(identityExpected, identityFunction.apply(identityExpected), 0.000001); + } + + @Test + public void testCreateDoubleDoubleFunction() { + double target = 0.5; + double output = 0.8; + + String squaredErrorName = "SquaredError"; + DoubleDoubleFunction squaredErrorFunction = FunctionFactory.createDoubleDoubleFunction(squaredErrorName); + assertEquals(squaredErrorName, squaredErrorFunction.getFunctionName()); + + double squaredErrorExpected = 0.045; + + assertEquals(squaredErrorExpected, squaredErrorFunction.apply(target, output), 0.000001); + + String crossEntropyName = "CrossEntropy"; + DoubleDoubleFunction crossEntropyFunction = FunctionFactory.createDoubleDoubleFunction(crossEntropyName); + assertEquals(crossEntropyName, crossEntropyFunction.getFunctionName()); + + double crossEntropyExpected = 0.91629; + assertEquals(crossEntropyExpected, crossEntropyFunction.apply(target, output), 0.000001); + } + +} Index: commons/src/test/java/org/apache/hama/commons/math/TestDenseDoubleMatrix.java =================================================================== --- commons/src/test/java/org/apache/hama/commons/math/TestDenseDoubleMatrix.java (revision 0) +++ commons/src/test/java/org/apache/hama/commons/math/TestDenseDoubleMatrix.java (revision 0) @@ -0,0 +1,237 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.commons.math; + +import static org.junit.Assert.assertArrayEquals; + +import org.junit.Test; + +/** + * Test case for {@link DenseDoubleMatrix} + * + */ +public class TestDenseDoubleMatrix { + + @Test + public void testDoubleFunction() { + double[][] values = new double[][] { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } }; + + double[][] result = new double[][] { { 2, 3, 4 }, { 5, 6, 7 }, { 8, 9, 10 } }; + + DenseDoubleMatrix mat = new DenseDoubleMatrix(values); + mat.applyToElements(new DoubleFunction() { + + @Override + public double apply(double value) { + return value + 1; + } + + @Override + public double applyDerivative(double value) { + throw new UnsupportedOperationException(); + } + + }); + + double[][] actual = mat.getValues(); + for (int i = 0; i < actual.length; ++i) { + assertArrayEquals(result[i], actual[i], 0.0001); + } + } + + @Test + public void testDoubleDoubleFunction() { + double[][] values1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } }; + double[][] values2 = new double[][] { { 2, 3, 4 }, { 5, 6, 7 }, + { 8, 9, 10 } }; + double[][] result = new double[][] { { 3, 5, 7 }, { 9, 11, 13 }, + { 15, 17, 19 } }; + + DenseDoubleMatrix mat1 = new DenseDoubleMatrix(values1); + DenseDoubleMatrix mat2 = new DenseDoubleMatrix(values2); + + mat1.applyToElements(mat2, new DoubleDoubleFunction() { + + @Override + public double apply(double x1, double x2) { + return x1 + x2; + } + + @Override + public double applyDerivative(double x1, double x2) { + throw new UnsupportedOperationException(); + } + + }); + + 
double[][] actual = mat1.getValues(); + for (int i = 0; i < actual.length; ++i) { + assertArrayEquals(result[i], actual[i], 0.0001); + } + } + + @Test + public void testMultiplyNormal() { + double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; + double[][] mat2 = new double[][] { { 6, 5 }, { 4, 3 }, { 2, 1 } }; + double[][] expMat = new double[][] { { 20, 14 }, { 56, 41 } }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); + DoubleMatrix actMatrix = matrix1.multiply(matrix2); + for (int r = 0; r < actMatrix.getRowCount(); ++r) { + assertArrayEquals(expMat[r], actMatrix.getRowVector(r).toArray(), + 0.000001); + } + } + + @Test(expected = IllegalArgumentException.class) + public void testMultiplyAbnormal() { + double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; + double[][] mat2 = new double[][] { { 6, 5 }, { 4, 3 } }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); + matrix1.multiply(matrix2); + } + + @Test + public void testMultiplyElementWiseNormal() { + double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; + double[][] mat2 = new double[][] { { 6, 5, 4 }, { 3, 2, 1 } }; + double[][] expMat = new double[][] { { 6, 10, 12 }, { 12, 10, 6 } }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); + DoubleMatrix actMatrix = matrix1.multiplyElementWise(matrix2); + for (int r = 0; r < actMatrix.getRowCount(); ++r) { + assertArrayEquals(expMat[r], actMatrix.getRowVector(r).toArray(), + 0.000001); + } + } + + @Test(expected = IllegalArgumentException.class) + public void testMultiplyElementWiseAbnormal() { + double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; + double[][] mat2 = new double[][] { { 6, 5 }, { 4, 3 } }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); + matrix1.multiplyElementWise(matrix2); 
+ } + + @Test + public void testMultiplyVectorNormal() { + double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; + double[] mat2 = new double[] { 6, 5, 4 }; + double[] expVec = new double[] { 28, 73 }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleVector vector2 = new DenseDoubleVector(mat2); + DoubleVector actVec = matrix1.multiplyVector(vector2); + assertArrayEquals(expVec, actVec.toArray(), 0.000001); + } + + @Test(expected = IllegalArgumentException.class) + public void testMultiplyVectorAbnormal() { + double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; + double[] vec2 = new double[] { 6, 5 }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleVector vector2 = new DenseDoubleVector(vec2); + matrix1.multiplyVector(vector2); + } + + @Test + public void testSubtractNormal() { + double[][] mat1 = new double[][] { + {1, 2, 3}, + {4, 5, 6} + }; + double[][] mat2 = new double[][] { + {6, 5, 4}, + {3, 2, 1} + }; + double[][] expMat = new double[][] { + {-5, -3, -1}, + {1, 3, 5} + }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); + DoubleMatrix actMatrix = matrix1.subtract(matrix2); + for (int r = 0; r < actMatrix.getRowCount(); ++r) { + assertArrayEquals(expMat[r], actMatrix.getRowVector(r).toArray(), 0.000001); + } + } + + @Test(expected = IllegalArgumentException.class) + public void testSubtractAbnormal() { + double[][] mat1 = new double[][] { + {1, 2, 3}, + {4, 5, 6} + }; + double[][] mat2 = new double[][] { + {6, 5}, + {4, 3} + }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); + matrix1.subtract(matrix2); + } + + @Test + public void testDivideVectorNormal() { + double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; + double[] mat2 = new double[] { 6, 5, 4 }; + double[][] expVec = new double[][] { {1.0 / 6, 2.0 / 5, 3.0 / 4}, {4.0 / 6, 5.0 / 5, 6.0 / 4} }; + DoubleMatrix matrix1 = new 
DenseDoubleMatrix(mat1); + DoubleVector vector2 = new DenseDoubleVector(mat2); + DoubleMatrix expMat = new DenseDoubleMatrix(expVec); + DoubleMatrix actMat = matrix1.divide(vector2); + for (int r = 0; r < actMat.getRowCount(); ++r) { + assertArrayEquals(expMat.getRowVector(r).toArray(), actMat.getRowVector(r).toArray(), 0.000001); + } + } + + @Test(expected = IllegalArgumentException.class) + public void testDivideVectorAbnormal() { + double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; + double[] vec2 = new double[] { 6, 5 }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleVector vector2 = new DenseDoubleVector(vec2); + matrix1.divide(vector2); + } + + @Test + public void testDivideNormal() { + double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; + double[][] mat2 = new double[][] { { 6, 5, 4 }, { 3, 2, 1 } }; + double[][] expMat = new double[][] { { 1.0 / 6, 2.0 / 5, 3.0 / 4 }, { 4.0 / 3, 5.0 / 2, 6.0 / 1 } }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); + DoubleMatrix actMatrix = matrix1.divide(matrix2); + for (int r = 0; r < actMatrix.getRowCount(); ++r) { + assertArrayEquals(expMat[r], actMatrix.getRowVector(r).toArray(), + 0.000001); + } + } + + @Test(expected = IllegalArgumentException.class) + public void testDivideAbnormal() { + double[][] mat1 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } }; + double[][] mat2 = new double[][] { { 6, 5 }, { 4, 3 } }; + DoubleMatrix matrix1 = new DenseDoubleMatrix(mat1); + DoubleMatrix matrix2 = new DenseDoubleMatrix(mat2); + matrix1.divide(matrix2); + } + +} Index: commons/src/main/java/org/apache/hama/commons/io/SocketIOWithTimeout.java =================================================================== --- commons/src/main/java/org/apache/hama/commons/io/SocketIOWithTimeout.java (revision 0) +++ commons/src/main/java/org/apache/hama/commons/io/SocketIOWithTimeout.java (revision 0) @@ -0,0 +1,453 @@ +/** + * Licensed to the 
Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.commons.io; + +import java.io.IOException; +import java.io.InterruptedIOException; +import java.net.SocketAddress; +import java.net.SocketTimeoutException; +import java.nio.ByteBuffer; +import java.nio.channels.SelectableChannel; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.nio.channels.spi.SelectorProvider; +import java.util.Iterator; +import java.util.LinkedList; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.util.StringUtils; + +/** + * This supports input and output streams for socket channels. These streams + * can have a timeout. + */ +public abstract class SocketIOWithTimeout { + // Public so that other packages (e.g. org.apache.hama.util.BSPNetUtils) can use it. + + static final Log LOG = LogFactory.getLog(SocketIOWithTimeout.class); + + private SelectableChannel channel; + private long timeout; + private boolean closed = false; + + private static SelectorPool selector = new SelectorPool(); + + /* + * A timeout value of 0 implies wait for ever. We should have a value of + * timeout that implies zero wait.. i.e. read or write returns immediately. 
+ * This will set channel to non-blocking. + */ + public SocketIOWithTimeout(SelectableChannel channel, long timeout) + throws IOException { + checkChannelValidity(channel); + + this.channel = channel; + this.timeout = timeout; + // Set non-blocking + channel.configureBlocking(false); + } + + public void close() { + closed = true; + } + + public boolean isOpen() { + return !closed && channel.isOpen(); + } + + public SelectableChannel getChannel() { + return channel; + } + + /** + * Utility function to check if channel is ok. Mainly to throw IOException + * instead of runtime exception in case of mismatch. This mismatch can occur + * for many runtime reasons. + */ + public static void checkChannelValidity(Object channel) throws IOException { + if (channel == null) { + /* + * Most common reason is that original socket does not have a channel. So + * making this an IOException rather than a RuntimeException. + */ + throw new IOException("Channel is null. Check " + + "how the channel or socket is created."); + } + + if (!(channel instanceof SelectableChannel)) { + throw new IOException("Channel should be a SelectableChannel"); + } + } + + /** + * Performs actual IO operations. This is not expected to block. + * + * @param buf buffer for IO + * @return number of bytes (or some equivalent). 0 implies underlying channel + * is drained completely. We will wait if more IO is required. + * @throws IOException + */ + abstract int performIO(ByteBuffer buf) throws IOException; + + /** + * Performs one IO and returns number of bytes read or written. It waits up to + * the specified timeout. If the channel is not ready before the timeout, + * SocketTimeoutException is thrown. + * + * @param buf buffer for IO + * @param ops Selection Ops used for waiting. Suggested values: + * SelectionKey.OP_READ while reading and SelectionKey.OP_WRITE while + * writing. + * + * @return number of bytes read or written. negative implies end of stream. 
+ * @throws IOException + */ + int doIO(ByteBuffer buf, int ops) throws IOException { + + /* + * For now only one thread is allowed. If user want to read or write from + * multiple threads, multiple streams could be created. In that case + * multiple threads work as well as underlying channel supports it. + */ + if (!buf.hasRemaining()) { + throw new IllegalArgumentException("Buffer has no data left."); + // or should we just return 0? + } + + while (buf.hasRemaining()) { + if (closed) { + return -1; + } + + try { + int n = performIO(buf); + if (n != 0) { + // successful io or an error. + return n; + } + } catch (IOException e) { + if (!channel.isOpen()) { + closed = true; + } + throw e; + } + + // now wait for socket to be ready. + int count = 0; + try { + count = selector.select(channel, ops, timeout); + } catch (IOException e) { // unexpected IOException. + closed = true; + throw e; + } + + if (count == 0) { + throw new SocketTimeoutException(timeoutExceptionString(channel, + timeout, ops)); + } + // otherwise the socket should be ready for io. + } + + return 0; // does not reach here. + } + + /** + * The contract is similar to {@link SocketChannel#connect(SocketAddress)} + * with a timeout. + * + * @see SocketChannel#connect(SocketAddress) + * + * @param channel - this should be a {@link SelectableChannel} + * @param endpoint + * @throws IOException + */ + public static void connect(SocketChannel channel, SocketAddress endpoint, + int timeout) throws IOException { + + boolean blockingOn = channel.isBlocking(); + if (blockingOn) { + channel.configureBlocking(false); + } + + try { + if (channel.connect(endpoint)) { + return; + } + + long timeoutLeft = timeout; + long endTime = (timeout > 0) ? 
(System.currentTimeMillis() + timeout) : 0; + + while (true) { + // we might have to call finishConnect() more than once + // for some channels (with user level protocols) + + int ret = selector.select((SelectableChannel) channel, + SelectionKey.OP_CONNECT, timeoutLeft); + + if (ret > 0 && channel.finishConnect()) { + return; + } + + if (ret == 0 + || (timeout > 0 && (timeoutLeft = (endTime - System + .currentTimeMillis())) <= 0)) { + throw new SocketTimeoutException(timeoutExceptionString(channel, + timeout, SelectionKey.OP_CONNECT)); + } + } + } catch (IOException e) { + // javadoc for SocketChannel.connect() says channel should be closed. + try { + channel.close(); + } catch (IOException ignored) { + } + throw e; + } finally { + if (blockingOn && channel.isOpen()) { + channel.configureBlocking(true); + } + } + } + + /** + * This is similar to {@link #doIO(ByteBuffer, int)} except that it does not + * perform any I/O. It just waits for the channel to be ready for I/O as + * specified in ops. + * + * @param ops Selection Ops used for waiting + * + * @throws SocketTimeoutException if select on the channel times out. + * @throws IOException if any other I/O error occurs. + */ + public void waitForIO(int ops) throws IOException { + + if (selector.select(channel, ops, timeout) == 0) { + throw new SocketTimeoutException(timeoutExceptionString(channel, timeout, + ops)); + } + } + + private static String timeoutExceptionString(SelectableChannel channel, + long timeout, int ops) { + + String waitingFor; + switch (ops) { + + case SelectionKey.OP_READ: + waitingFor = "read"; + break; + + case SelectionKey.OP_WRITE: + waitingFor = "write"; + break; + + case SelectionKey.OP_CONNECT: + waitingFor = "connect"; + break; + + default: + waitingFor = "" + ops; + } + + return timeout + " millis timeout while " + + "waiting for channel to be ready for " + waitingFor + ". ch : " + + channel; + } + + /** + * This maintains a pool of selectors. 
These selectors are closed once they + * are idle (unused) for a few seconds. + */ + private static class SelectorPool { + + private static class SelectorInfo { + Selector selector; + long lastActivityTime; + LinkedList queue; + + void close() { + if (selector != null) { + try { + selector.close(); + } catch (IOException e) { + LOG.warn("Unexpected exception while closing selector : " + + StringUtils.stringifyException(e)); + } + } + } + } + + private static class ProviderInfo { + SelectorProvider provider; + LinkedList queue; // lifo + ProviderInfo next; + } + + private static final long IDLE_TIMEOUT = 10 * 1000; // 10 seconds. + + private ProviderInfo providerList = null; + + /** + * Waits on the channel with the given timeout using one of the cached + * selectors. It also removes any cached selectors that are idle for a few + * seconds. + * + * @param channel + * @param ops + * @param timeout + * @return + * @throws IOException + */ + int select(SelectableChannel channel, int ops, long timeout) + throws IOException { + + SelectorInfo info = get(channel); + + SelectionKey key = null; + int ret = 0; + + try { + while (true) { + long start = (timeout == 0) ? 0 : System.currentTimeMillis(); + + key = channel.register(info.selector, ops); + ret = info.selector.select(timeout); + + if (ret != 0) { + return ret; + } + + /* + * Sometimes select() returns 0 much before timeout for unknown + * reasons. So select again if required. + */ + if (timeout > 0) { + timeout -= System.currentTimeMillis() - start; + if (timeout <= 0) { + return 0; + } + } + + if (Thread.currentThread().isInterrupted()) { + throw new InterruptedIOException("Interruped while waiting for " + + "IO on channel " + channel + ". " + timeout + + " millis timeout left."); + } + } + } finally { + if (key != null) { + key.cancel(); + } + + // clear the canceled key. 
+ try { + info.selector.selectNow(); + } catch (IOException e) { + LOG.info("Unexpected Exception while clearing selector : " + + StringUtils.stringifyException(e)); + // don't put the selector back. + info.close(); + return ret; + } + + release(info); + } + } + + /** + * Takes one selector from end of LRU list of free selectors. If there are + * no selectors awailable, it creates a new selector. Also invokes + * trimIdleSelectors(). + * + * @param channel + * @return + * @throws IOException + */ + private synchronized SelectorInfo get(SelectableChannel channel) + throws IOException { + SelectorInfo selInfo = null; + + SelectorProvider provider = channel.provider(); + + // pick the list : rarely there is more than one provider in use. + ProviderInfo pList = providerList; + while (pList != null && pList.provider != provider) { + pList = pList.next; + } + if (pList == null) { + // LOG.info("Creating new ProviderInfo : " + provider.toString()); + pList = new ProviderInfo(); + pList.provider = provider; + pList.queue = new LinkedList(); + pList.next = providerList; + providerList = pList; + } + + LinkedList queue = pList.queue; + + if (queue.isEmpty()) { + Selector selector = provider.openSelector(); + selInfo = new SelectorInfo(); + selInfo.selector = selector; + selInfo.queue = queue; + } else { + selInfo = queue.removeLast(); + } + + trimIdleSelectors(System.currentTimeMillis()); + return selInfo; + } + + /** + * puts selector back at the end of LRU list of free selectos. Also invokes + * trimIdleSelectors(). + * + * @param info + */ + private synchronized void release(SelectorInfo info) { + long now = System.currentTimeMillis(); + trimIdleSelectors(now); + info.lastActivityTime = now; + info.queue.addLast(info); + } + + /** + * Closes selectors that are idle for IDLE_TIMEOUT (10 sec). It does not + * traverse the whole list, just over the one that have crossed the timeout. 
+ */ + private void trimIdleSelectors(long now) { + long cutoff = now - IDLE_TIMEOUT; + + for (ProviderInfo pList = providerList; pList != null; pList = pList.next) { + if (pList.queue.isEmpty()) { + continue; + } + for (Iterator it = pList.queue.iterator(); it.hasNext();) { + SelectorInfo info = it.next(); + if (info.lastActivityTime > cutoff) { + break; + } + it.remove(); + info.close(); + } + } + } + } +} Index: commons/src/main/java/org/apache/hama/commons/io/SocketInputStream.java =================================================================== --- commons/src/main/java/org/apache/hama/commons/io/SocketInputStream.java (revision 0) +++ commons/src/main/java/org/apache/hama/commons/io/SocketInputStream.java (revision 0) @@ -0,0 +1,168 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.hama.commons.io;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.Socket;
+import java.net.SocketTimeoutException;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.channels.ReadableByteChannel;
+import java.nio.channels.SelectableChannel;
+import java.nio.channels.SelectionKey;
+
+/**
+ * This implements an input stream that can have a timeout while reading. This
+ * sets non-blocking flag on the socket channel. So after creating this object,
+ * read() on {@link Socket#getInputStream()} and write() on
+ * {@link Socket#getOutputStream()} for the associated socket will throw
+ * IllegalBlockingModeException. Please use {@link SocketOutputStream} for
+ * writing.
+ */
+public class SocketInputStream extends InputStream implements
+    ReadableByteChannel {
+
+  private Reader reader;
+
+  private static class Reader extends SocketIOWithTimeout {
+    ReadableByteChannel channel;
+
+    Reader(ReadableByteChannel channel, long timeout) throws IOException {
+      super((SelectableChannel) channel, timeout);
+      this.channel = channel;
+    }
+
+    int performIO(ByteBuffer buf) throws IOException {
+      return channel.read(buf);
+    }
+  }
+
+  /**
+   * Create a new input stream with the given timeout. If the timeout is zero,
+   * it will be treated as infinite timeout. The socket's channel will be
+   * configured to be non-blocking.
+   *
+   * @param channel Channel for reading, should also be a
+   *          {@link SelectableChannel}. The channel will be configured to be
+   *          non-blocking.
+   * @param timeout timeout in milliseconds. must not be negative.
+   * @throws IOException if the channel is null or not a SelectableChannel
+   */
+  public SocketInputStream(ReadableByteChannel channel, long timeout)
+      throws IOException {
+    SocketIOWithTimeout.checkChannelValidity(channel);
+    reader = new Reader(channel, timeout);
+  }
+
+  /**
+   * Same as SocketInputStream(socket.getChannel(), timeout):<br>
+   * <br>
+   *
+   * Create a new input stream with the given timeout. If the timeout is zero,
+   * it will be treated as infinite timeout. The socket's channel will be
+   * configured to be non-blocking.
+   *
+   * @see SocketInputStream#SocketInputStream(ReadableByteChannel, long)
+   *
+   * @param socket should have a channel associated with it.
+   * @param timeout timeout in milliseconds. must not be negative.
+   * @throws IOException if the socket has no channel
+   */
+  public SocketInputStream(Socket socket, long timeout) throws IOException {
+    this(socket.getChannel(), timeout);
+  }
+
+  /**
+   * Same as SocketInputStream(socket.getChannel(), socket.getSoTimeout()) :<br>
+   * <br>
+   *
+   * Create a new input stream with the given timeout. If the timeout is zero,
+   * it will be treated as infinite timeout. The socket's channel will be
+   * configured to be non-blocking.
+   *
+   * @see SocketInputStream#SocketInputStream(ReadableByteChannel, long)
+   *
+   * @param socket should have a channel associated with it.
+   * @throws IOException if the socket has no channel
+   */
+  public SocketInputStream(Socket socket) throws IOException {
+    this(socket.getChannel(), socket.getSoTimeout());
+  }
+
+  @Override
+  public int read() throws IOException {
+    /*
+     * Allocation can be removed if required. probably no need to optimize or
+     * encourage single byte read.
+     */
+    byte[] buf = new byte[1];
+    int ret = read(buf, 0, 1);
+    if (ret > 0) {
+      return buf[0] & 0xff; // unsigned 0-255, so 0xff is not taken for EOF (-1)
+    }
+    if (ret != -1) {
+      // unexpected
+      throw new IOException("Could not read from stream");
+    }
+    return ret;
+  }
+
+  public int read(byte[] b, int off, int len) throws IOException {
+    return read(ByteBuffer.wrap(b, off, len));
+  }
+
+  public synchronized void close() throws IOException {
+    /*
+     * close the channel since Socket.getInputStream().close() closes the
+     * socket.
+     */
+    reader.channel.close();
+    reader.close();
+  }
+
+  /**
+   * Returns underlying channel used by inputstream. This is useful in certain
+   * cases like channel for
+   * {@link FileChannel#transferFrom(ReadableByteChannel, long, long)}.
+   */
+  public ReadableByteChannel getChannel() {
+    return reader.channel;
+  }
+
+  // ReadableByteChannel interface
+
+  public boolean isOpen() {
+    return reader.isOpen();
+  }
+
+  public int read(ByteBuffer dst) throws IOException {
+    return reader.doIO(dst, SelectionKey.OP_READ);
+  }
+
+  /**
+   * waits for the underlying channel to be ready for reading. The timeout
+   * specified for this stream applies to this wait.
+   *
+   * @throws SocketTimeoutException if select on the channel times out.
+   * @throws IOException if any other I/O error occurs.
+   */
+  public void waitForReadable() throws IOException {
+    reader.waitForIO(SelectionKey.OP_READ);
+  }
+}
Index: commons/src/main/java/org/apache/hama/commons/io/SocketOutputStream.java
===================================================================
--- commons/src/main/java/org/apache/hama/commons/io/SocketOutputStream.java (revision 0)
+++ commons/src/main/java/org/apache/hama/commons/io/SocketOutputStream.java (revision 0)
@@ -0,0 +1,213 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.commons.io;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.net.Socket;
+import java.net.SocketTimeoutException;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.channels.SelectableChannel;
+import java.nio.channels.SelectionKey;
+import java.nio.channels.WritableByteChannel;
+
+/**
+ * This implements an output stream that can have a timeout while writing. This
+ * sets non-blocking flag on the socket channel.
So after creating this object,
+ * read() on {@link Socket#getInputStream()} and write() on
+ * {@link Socket#getOutputStream()} on the associated socket will throw
+ * IllegalBlockingModeException. Please use {@link SocketInputStream} for
+ * reading.
+ */
+public class SocketOutputStream extends OutputStream implements
+    WritableByteChannel {
+
+  private Writer writer;
+
+  private static class Writer extends SocketIOWithTimeout {
+    WritableByteChannel channel;
+
+    Writer(WritableByteChannel channel, long timeout) throws IOException {
+      super((SelectableChannel) channel, timeout);
+      this.channel = channel;
+    }
+
+    int performIO(ByteBuffer buf) throws IOException {
+      return channel.write(buf);
+    }
+  }
+
+  /**
+   * Create a new output stream with the given timeout. If the timeout is zero,
+   * it will be treated as infinite timeout. The socket's channel will be
+   * configured to be non-blocking.
+   *
+   * @param channel Channel for writing, should also be a
+   *          {@link SelectableChannel}. The channel will be configured to be
+   *          non-blocking.
+   * @param timeout timeout in milliseconds. must not be negative.
+   * @throws IOException if the channel is null or not a SelectableChannel
+   */
+  public SocketOutputStream(WritableByteChannel channel, long timeout)
+      throws IOException {
+    SocketIOWithTimeout.checkChannelValidity(channel);
+    writer = new Writer(channel, timeout);
+  }
+
+  /**
+   * Same as SocketOutputStream(socket.getChannel(), timeout):<br>
+   * <br>
+   *
+   * Create a new output stream with the given timeout. If the timeout is zero,
+   * it will be treated as infinite timeout. The socket's channel will be
+   * configured to be non-blocking.
+   *
+   * @see SocketOutputStream#SocketOutputStream(WritableByteChannel, long)
+   *
+   * @param socket should have a channel associated with it.
+   * @param timeout timeout in milliseconds. must not be negative.
+   * @throws IOException if the socket has no channel
+   */
+  public SocketOutputStream(Socket socket, long timeout) throws IOException {
+    this(socket.getChannel(), timeout);
+  }
+
+  public void write(int b) throws IOException {
+    /*
+     * If we need to, we can optimize this allocation. probably no need to
+     * optimize or encourage single byte writes.
+     */
+    byte[] buf = new byte[1];
+    buf[0] = (byte) b;
+    write(buf, 0, 1);
+  }
+
+  public void write(byte[] b, int off, int len) throws IOException {
+    ByteBuffer buf = ByteBuffer.wrap(b, off, len);
+    while (buf.hasRemaining()) {
+      try {
+        if (write(buf) < 0) {
+          throw new IOException("The stream is closed");
+        }
+      } catch (IOException e) {
+        /*
+         * Unlike read, write can not inform user of partial writes. So close
+         * this only if some of the len bytes were already written.
+         */
+        if (buf.remaining() < len) {
+          writer.close();
+        }
+        throw e;
+      }
+    }
+  }
+
+  public synchronized void close() throws IOException {
+    /*
+     * close the channel since Socket.getOutputStream().close() closes the
+     * socket.
+     */
+    writer.channel.close();
+    writer.close();
+  }
+
+  /**
+   * Returns underlying channel used by this stream. This is useful in certain
+   * cases like channel for
+   * {@link FileChannel#transferTo(long, long, WritableByteChannel)}
+   */
+  public WritableByteChannel getChannel() {
+    return writer.channel;
+  }
+
+  // WritableByteChannel interface
+
+  public boolean isOpen() {
+    return writer.isOpen();
+  }
+
+  public int write(ByteBuffer src) throws IOException {
+    return writer.doIO(src, SelectionKey.OP_WRITE);
+  }
+
+  /**
+   * waits for the underlying channel to be ready for writing. The timeout
+   * specified for this stream applies to this wait.
+   *
+   * @throws SocketTimeoutException if select on the channel times out.
+   * @throws IOException if any other I/O error occurs.
+   */
+  public void waitForWritable() throws IOException {
+    writer.waitForIO(SelectionKey.OP_WRITE);
+  }
+
+  /**
+   * Transfers data from FileChannel using
+   * {@link FileChannel#transferTo(long, long, WritableByteChannel)}.
+   *
+   * Similar to readFully(), this waits till requested amount of data is
+   * transferred.
+   *
+   * @param fileCh FileChannel to transfer data from.
+   * @param position position within the channel where the transfer begins
+   * @param count number of bytes to transfer.
+   *
+   * @throws EOFException If end of input file is reached before requested
+   *           number of bytes are transferred.
+   *
+   * @throws SocketTimeoutException If this channel blocks transfer longer than
+   *           timeout for this stream.
+   *
+   * @throws IOException Includes any exception thrown by
+   *           {@link FileChannel#transferTo(long, long, WritableByteChannel)}.
+   */
+  public void transferToFully(FileChannel fileCh, long position, int count)
+      throws IOException {
+
+    while (count > 0) {
+      /*
+       * Ideally we should wait after transferTo returns 0. But because of a bug
+       * in JRE on Linux (http://bugs.sun.com/view_bug.do?bug_id=5103988), which
+       * throws an exception instead of returning 0, we wait for the channel to
+       * be writable before writing to it. If you ever see IOException with
+       * message "Resource temporarily unavailable" thrown here, please let us
+       * know. Once we move to JAVA SE 7, wait should be moved to correct place.
+       */
+      waitForWritable();
+      int nTransfered = (int) fileCh.transferTo(position, count, getChannel());
+
+      if (nTransfered == 0) {
+        // check if end of file is reached.
+        if (position >= fileCh.size()) {
+          throw new EOFException("EOF Reached. file size is " + fileCh.size()
+              + " and " + count + " more bytes left to be " + "transfered.");
+        }
+        // otherwise assume the socket is full.
+        // waitForWritable(); // see comment above.
+      } else if (nTransfered < 0) {
+        throw new IOException("Unexpected return of " + nTransfered
+            + " from transferTo()");
+      } else {
+        position += nTransfered;
+        count -= nTransfered;
+      }
+    }
+  }
+}
Index: commons/src/main/java/org/apache/hama/commons/io/writable/SparseVectorWritable.java
===================================================================
--- commons/src/main/java/org/apache/hama/commons/io/writable/SparseVectorWritable.java (revision 0)
+++ commons/src/main/java/org/apache/hama/commons/io/writable/SparseVectorWritable.java (revision 0)
@@ -0,0 +1,105 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.commons.io.writable;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.io.Writable;
+
+/**
+ * This class represents sparse vector. It will give improvement in memory
+ * consumption in case of vectors which sparsity is close to zero. Can be used
+ * in SpMV for representing input matrix rows efficiently. Internally represents
+ * values as list of indices and list of values.
+ */
+public class SparseVectorWritable implements Writable {
+
+  private Integer size; // null until setSize() is called
+  private List<Integer> indeces;
+  private List<Double> values;
+
+  public SparseVectorWritable() {
+    indeces = new ArrayList<Integer>();
+    values = new ArrayList<Double>();
+  }
+
+  public void clear() {
+    indeces = new ArrayList<Integer>();
+    values = new ArrayList<Double>();
+  }
+
+  public void addCell(int index, double value) {
+    indeces.add(index);
+    values.add(value);
+  }
+
+  public void setSize(int size) {
+    this.size = size;
+  }
+
+  public int getSize() {
+    if (size != null)
+      return size;
+    return indeces.size(); // no explicit size set: fall back to the cell count
+  }
+
+  public List<Integer> getIndeces() {
+    return indeces;
+  }
+
+  public List<Double> getValues() {
+    return values;
+  }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    clear();
+    int size = in.readInt();
+    int len = in.readInt();
+    setSize(size);
+    for (int i = 0; i < len; i++) {
+      int index = in.readInt();
+      double value = in.readDouble();
+      this.addCell(index, value);
+    }
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeInt(getSize());
+    out.writeInt(indeces.size());
+    for (int i = 0; i < indeces.size(); i++) {
+      out.writeInt(indeces.get(i));
+      out.writeDouble(values.get(i));
+    }
+  }
+
+  @Override
+  public String toString() {
+    StringBuilder st = new StringBuilder();
+    st.append(" " + getSize() + " " + indeces.size());
+    for (int i = 0; i < indeces.size(); i++)
+      st.append(" " + indeces.get(i) + " " + values.get(i));
+    return st.toString();
+  }
+
+}
Index: commons/src/main/java/org/apache/hama/commons/io/writable/DenseVectorWritable.java
===================================================================
--- commons/src/main/java/org/apache/hama/commons/io/writable/DenseVectorWritable.java (revision 0)
+++ commons/src/main/java/org/apache/hama/commons/io/writable/DenseVectorWritable.java (revision 0)
@@ -0,0 +1,87 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.commons.io.writable;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.Writable;
+
+/**
+ * This class represents dense vector. It will improve memory consumption up to
+ * two times in comparison to SparseVectorWritable in case of vectors which
+ * sparsity is close to 1. Internally represents vector values as array. Can be
+ * used in SpMV for representation of input and output vector.
+ */
+public class DenseVectorWritable implements Writable {
+
+  private double[] values;
+
+  public DenseVectorWritable() {
+    values = new double[0];
+  }
+
+  public int getSize() {
+    return values.length;
+  }
+
+  public void setSize(int size) {
+    values = new double[size]; // fresh zero-filled array; old values are dropped
+  }
+
+  public double get(int index) {
+    return values[index];
+  }
+
+  public void addCell(int index, double value) {
+    values[index] = value;
+  }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    int size = in.readInt();
+    int len = in.readInt(); // number of (index, value) pairs that follow
+    setSize(size);
+    for (int i = 0; i < len; i++) {
+      int index = in.readInt();
+      double value = in.readDouble();
+      values[index] = value;
+    }
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeInt(getSize()); // vector size
+    out.writeInt(getSize()); // pair count: every cell is written
+    for (int i = 0; i < getSize(); i++) {
+      out.writeInt(i);
+      out.writeDouble(values[i]);
+    }
+  }
+
+  @Override
+  public String toString() {
+    StringBuilder st = new StringBuilder();
+    st.append(" " + getSize() + " " + getSize());
+    for (int i = 0; i < getSize(); i++)
+      st.append(" " + i + " " + values[i]);
+    return st.toString();
+  }
+
+}
Index: commons/src/main/java/org/apache/hama/commons/io/writable/VectorWritable.java
===================================================================
--- commons/src/main/java/org/apache/hama/commons/io/writable/VectorWritable.java (revision 0)
+++ commons/src/main/java/org/apache/hama/commons/io/writable/VectorWritable.java (revision 0)
@@ -0,0 +1,133 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.commons.io.writable;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hama.commons.math.DenseDoubleVector;
+import org.apache.hama.commons.math.DoubleVector;
+
+/**
+ * Writable for dense vectors.
+ */
+public final class VectorWritable implements WritableComparable<VectorWritable> {
+
+  private DoubleVector vector;
+
+  public VectorWritable() {
+    super();
+  }
+
+  public VectorWritable(VectorWritable v) {
+    this.vector = v.getVector();
+  }
+
+  public VectorWritable(DoubleVector v) {
+    this.vector = v;
+  }
+
+  @Override
+  public final void write(DataOutput out) throws IOException {
+    writeVector(this.vector, out);
+  }
+
+  @Override
+  public final void readFields(DataInput in) throws IOException {
+    this.vector = readVector(in);
+  }
+
+  @Override
+  public final int compareTo(VectorWritable o) {
+    return compareVector(this, o);
+  }
+
+  @Override
+  public int hashCode() {
+    final int prime = 31;
+    int result = 1;
+    result = prime * result + ((vector == null) ? 0 : vector.hashCode());
+    return result;
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj)
+      return true;
+    if (obj == null)
+      return false;
+    if (getClass() != obj.getClass())
+      return false;
+    VectorWritable other = (VectorWritable) obj;
+    if (vector == null) {
+      if (other.vector != null)
+        return false;
+    } else if (!vector.equals(other.vector))
+      return false;
+    return true;
+  }
+
+  /**
+   * @return the embedded vector
+   */
+  public DoubleVector getVector() {
+    return vector;
+  }
+
+  @Override
+  public String toString() {
+    return String.valueOf(vector); // "null" instead of NPE when no vector is set
+  }
+
+  public static void writeVector(DoubleVector vector, DataOutput out)
+      throws IOException {
+    out.writeInt(vector.getLength()); // NOTE(review): loop bound is getDimension(); assumed equal - confirm
+    for (int i = 0; i < vector.getDimension(); i++) {
+      out.writeDouble(vector.get(i));
+    }
+  }
+
+  public static DoubleVector readVector(DataInput in) throws IOException {
+    int length = in.readInt();
+    DoubleVector vector;
+    vector = new DenseDoubleVector(length);
+    for (int i = 0; i < length; i++) {
+      vector.set(i, in.readDouble());
+    }
+    return vector;
+  }
+
+  public static int compareVector(VectorWritable a, VectorWritable o) {
+    return compareVector(a.getVector(), o.getVector());
+  }
+
+  public static int compareVector(DoubleVector a, DoubleVector o) {
+    DoubleVector subtract = a.subtractUnsafe(o);
+    return (int) subtract.sum(); // NOTE(review): truncates toward zero, so |sum| < 1 compares as equal
+  }
+
+  public static VectorWritable wrap(DoubleVector a) {
+    return new VectorWritable(a);
+  }
+
+  public void set(DoubleVector vector) {
+    this.vector = vector;
+  }
+}
Index: commons/src/main/java/org/apache/hama/commons/io/writable/TextArrayWritable.java
===================================================================
--- commons/src/main/java/org/apache/hama/commons/io/writable/TextArrayWritable.java (revision 0)
+++ commons/src/main/java/org/apache/hama/commons/io/writable/TextArrayWritable.java (revision 0)
@@ -0,0 +1,29 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license
agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.commons.io.writable;
+
+import org.apache.hadoop.io.ArrayWritable;
+import org.apache.hadoop.io.Text;
+
+public class TextArrayWritable extends ArrayWritable {
+
+  public TextArrayWritable() {
+    super(Text.class); // fixes the element class to Text
+  }
+
+}
Index: commons/src/main/java/org/apache/hama/commons/io/writable/MatrixWritable.java
===================================================================
--- commons/src/main/java/org/apache/hama/commons/io/writable/MatrixWritable.java (revision 0)
+++ commons/src/main/java/org/apache/hama/commons/io/writable/MatrixWritable.java (revision 0)
@@ -0,0 +1,73 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.commons.io.writable; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +import org.apache.hadoop.io.Writable; +import org.apache.hama.commons.math.DenseDoubleMatrix; +import org.apache.hama.commons.math.DoubleMatrix; + +/** + * Majorly designed for dense matrices, can be extended for sparse ones as well. + */ +public final class MatrixWritable implements Writable { + + private DoubleMatrix mat; + + public MatrixWritable() { + } + + public MatrixWritable(DoubleMatrix mat) { + this.mat = mat; + + } + + @Override + public void readFields(DataInput in) throws IOException { + mat = read(in); + } + + @Override + public void write(DataOutput out) throws IOException { + write(mat, out); + } + + public static void write(DoubleMatrix mat, DataOutput out) throws IOException { + out.writeInt(mat.getRowCount()); + out.writeInt(mat.getColumnCount()); + for (int row = 0; row < mat.getRowCount(); row++) { + for (int col = 0; col < mat.getColumnCount(); col++) { + out.writeDouble(mat.get(row, col)); + } + } + } + + public static DoubleMatrix read(DataInput in) throws IOException { + DoubleMatrix mat = new DenseDoubleMatrix(in.readInt(), in.readInt()); + for (int row = 0; row < mat.getRowCount(); row++) { + for (int col = 0; col < mat.getColumnCount(); col++) { + mat.set(row, col, in.readDouble()); + } + } + return mat; + } + +} Index: commons/src/main/java/org/apache/hama/commons/io/writable/StringArrayWritable.java =================================================================== --- 
commons/src/main/java/org/apache/hama/commons/io/writable/StringArrayWritable.java (revision 0) +++ commons/src/main/java/org/apache/hama/commons/io/writable/StringArrayWritable.java (revision 0) @@ -0,0 +1,65 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.commons.io.writable; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +import org.apache.hadoop.io.Writable; + +/** + * Custom writable for string arrays, because ArrayWritable has no default + * constructor and is broken. + * + */ +public class StringArrayWritable implements Writable { + + private String[] array; + + public StringArrayWritable() { + super(); + } + + public StringArrayWritable(String[] array) { + super(); + this.array = array; + } + + // no defensive copy needed because this always comes from an rpc call. 
+ public String[] get() { + return array; + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeInt(array.length); + for (String s : array) { + out.writeUTF(s); + } + } + + @Override + public void readFields(DataInput in) throws IOException { + array = new String[in.readInt()]; + for (int i = 0; i < array.length; i++) { + array[i] = in.readUTF(); + } + } + +} Index: commons/src/main/java/org/apache/hama/commons/math/CrossEntropy.java =================================================================== --- commons/src/main/java/org/apache/hama/commons/math/CrossEntropy.java (revision 0) +++ commons/src/main/java/org/apache/hama/commons/math/CrossEntropy.java (revision 0) @@ -0,0 +1,58 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.commons.math; + +/** + * The cross entropy cost function. + * + *
+ * cost(t, y) = - t * log(y) - (1 - t) * log(1 - y),
+ * where t denotes the target value, y denotes the estimated value.
+ * 
+ */
+public class CrossEntropy extends DoubleDoubleFunction {
+
+  /**
+   * Computes the cross entropy cost of an estimate against its target.
+   * <p>
+   * A target or estimate that is exactly 0 is replaced by 0.000001 and one
+   * that is exactly 1 by 0.999999 before the logarithms are taken, so that
+   * Math.log is not called with 0 for those inputs.
+   *
+   * @param target the target value t.
+   * @param actual the estimated value y.
+   * @return -t * log(y) - (1 - t) * log(1 - y), computed on the adjusted
+   *         values.
+   */
+  @Override
+  public double apply(double target, double actual) {
+    // Each value is adjusted in a single step. Previously the upper-bound
+    // check re-read the raw input and overwrote the lower-bound adjustment,
+    // so an input of exactly 0 reached Math.log unadjusted and the cost
+    // became NaN or infinite.
+    double adjustedTarget = awayFromBounds(target, 0.000001, 0.999999);
+    double adjustedActual = awayFromBounds(actual, 0.000001, 0.999999);
+    return -adjustedTarget * Math.log(adjustedActual) - (1 - adjustedTarget)
+        * Math.log(1 - adjustedActual);
+  }
+
+  /**
+   * Computes the derivative of the cross entropy cost with respect to the
+   * estimate.
+   * <p>
+   * A target or estimate that is exactly 0 is replaced by 0.001 and one that
+   * is exactly 1 by 0.999, which keeps both divisors away from zero for those
+   * inputs.
+   *
+   * @param target the target value t.
+   * @param actual the estimated value y.
+   * @return -t / y + (1 - t) / (1 - y), computed on the adjusted values.
+   */
+  @Override
+  public double applyDerivative(double target, double actual) {
+    double adjustedTarget = awayFromBounds(target, 0.001, 0.999);
+    double adjustedActual = awayFromBounds(actual, 0.001, 0.999);
+    return -adjustedTarget / adjustedActual + (1 - adjustedTarget)
+        / (1 - adjustedActual);
+  }
+
+  /**
+   * Returns atZero when value is exactly 0, atOne when value is exactly 1,
+   * and value itself otherwise.
+   */
+  private static double awayFromBounds(double value, double atZero,
+      double atOne) {
+    if (value == 0) {
+      return atZero;
+    }
+    if (value == 1) {
+      return atOne;
+    }
+    return value;
+  }
+
+}
Index: commons/src/main/java/org/apache/hama/commons/math/DoubleDoubleFunction.java
===================================================================
--- commons/src/main/java/org/apache/hama/commons/math/DoubleDoubleFunction.java (revision 0)
+++ commons/src/main/java/org/apache/hama/commons/math/DoubleDoubleFunction.java (revision 0)
@@ -0,0 +1,45 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.commons.math;
+
+/**
+ * A function of two double arguments. A vector or matrix can apply it
+ * element by element, passing its own element as the first argument and the
+ * matching element of a second operand as the second argument.
+ */
+public abstract class DoubleDoubleFunction extends Function {
+
+  /**
+   * Applies the function to the two given arguments.
+   *
+   * @param x1 the first argument.
+   * @param x2 the second argument.
+   * @return the value of the function at (x1, x2).
+   */
+  public abstract double apply(double x1, double x2);
+
+  /**
+   * Applies the derivative of this function to the two given arguments.
+   *
+   * @param x1 the first argument.
+   * @param x2 the second argument.
+   * @return the value of the derivative at (x1, x2).
+   */
+  public abstract double applyDerivative(double x1, double x2);
+
+}
Index: commons/src/main/java/org/apache/hama/commons/math/DenseDoubleVector.java
===================================================================
--- commons/src/main/java/org/apache/hama/commons/math/DenseDoubleVector.java (revision 0)
+++ commons/src/main/java/org/apache/hama/commons/math/DenseDoubleVector.java (revision 0)
@@ -0,0 +1,739 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.commons.math; + +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; + +import com.google.common.base.Preconditions; +import com.google.common.collect.AbstractIterator; + +/** + * Dense double vector implementation. + */ +public final class DenseDoubleVector implements DoubleVector { + + private final double[] vector; + + /** + * Creates a new vector with the given length. + */ + public DenseDoubleVector(int length) { + this.vector = new double[length]; + } + + /** + * Creates a new vector with the given length and default value. + */ + public DenseDoubleVector(int length, double val) { + this(length); + Arrays.fill(vector, val); + } + + /** + * Creates a new vector with the given array. + */ + public DenseDoubleVector(double[] arr) { + this.vector = arr; + } + + /** + * Creates a new vector with the given array and the last value f1. 
+ */ + public DenseDoubleVector(double[] array, double f1) { + this.vector = new double[array.length + 1]; + System.arraycopy(array, 0, this.vector, 0, array.length); + this.vector[array.length] = f1; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleVector#get(int) + */ + @Override + public final double get(int index) { + return vector[index]; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleVector#getLength() + */ + @Override + public final int getLength() { + return vector.length; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleVector#getDimension() + */ + @Override + public int getDimension() { + return getLength(); + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleVector#set(int, double) + */ + @Override + public final void set(int index, double value) { + vector[index] = value; + } + + /** + * {@inheritDoc} + */ + @Override + public DoubleVector applyToElements(DoubleFunction func) { + for (int i = 0; i < vector.length; i++) { + this.vector[i] = func.apply(vector[i]); + } + return this; + } + + /** + * {@inheritDoc} + */ + @Override + public DoubleVector applyToElements(DoubleVector other, + DoubleDoubleFunction func) { + for (int i = 0; i < vector.length; i++) { + this.vector[i] = func.apply(vector[i], other.get(i)); + } + return this; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleVector#apply(de.jungblut.math.function. 
+   * DoubleVectorFunction)
+   */
+  @Deprecated
+  @Override
+  public DoubleVector apply(DoubleVectorFunction func) {
+    // Allocate fresh storage for the result. Wrapping this.vector in a new
+    // DenseDoubleVector shares the backing array, so the loop below would
+    // overwrite this vector's own elements while building the "new" one.
+    DenseDoubleVector newV = new DenseDoubleVector(vector.length);
+    for (int i = 0; i < vector.length; i++) {
+      newV.vector[i] = func.calculate(i, vector[i]);
+    }
+    return newV;
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#apply(de.jungblut.math.DoubleVector,
+   * de.jungblut.math.function.DoubleDoubleVectorFunction)
+   */
+  @Deprecated
+  @Override
+  public DoubleVector apply(DoubleVector other, DoubleDoubleVectorFunction func) {
+    // deepCopy() gives the result its own array, so this vector is untouched.
+    DenseDoubleVector newV = (DenseDoubleVector) deepCopy();
+    for (int i = 0; i < vector.length; i++) {
+      newV.vector[i] = func.calculate(i, vector[i], other.get(i));
+    }
+    return newV;
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#add(de.jungblut.math.DoubleVector)
+   */
+  @Override
+  public final DoubleVector addUnsafe(DoubleVector v) {
+    // The result is sized by the argument; no dimension check is performed.
+    DenseDoubleVector newv = new DenseDoubleVector(v.getLength());
+    for (int i = 0; i < v.getLength(); i++) {
+      newv.set(i, this.get(i) + v.get(i));
+    }
+    return newv;
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#add(double)
+   */
+  @Override
+  public final DoubleVector add(double scalar) {
+    DoubleVector newv = new DenseDoubleVector(this.getLength());
+    for (int i = 0; i < this.getLength(); i++) {
+      newv.set(i, this.get(i) + scalar);
+    }
+    return newv;
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#subtract(de.jungblut.math.DoubleVector)
+   */
+  @Override
+  public final DoubleVector subtractUnsafe(DoubleVector v) {
+    // The result is sized by the argument; no dimension check is performed.
+    DoubleVector newv = new DenseDoubleVector(v.getLength());
+    for (int i = 0; i < v.getLength(); i++) {
+      newv.set(i, this.get(i) - v.get(i));
+    }
+    return newv;
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#subtract(double)
+   */
+  @Override
+  public final DoubleVector subtract(double v) {
+    DenseDoubleVector newv = new DenseDoubleVector(vector.length);
+    for (int i = 0; i < vector.length; i++) {
+      newv.set(i, vector[i] - v);
+    }
+    return newv;
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#subtractFrom(double)
+   */
+  @Override
+  public final DoubleVector subtractFrom(double v) {
+    DenseDoubleVector newv = new DenseDoubleVector(vector.length);
+    for (int i = 0; i < vector.length; i++) {
+      newv.set(i, v - vector[i]);
+    }
+    return newv;
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#multiply(double)
+   */
+  @Override
+  public DoubleVector multiply(double scalar) {
+    DoubleVector v = new DenseDoubleVector(this.getLength());
+    for (int i = 0; i < v.getLength(); i++) {
+      v.set(i, this.get(i) * scalar);
+    }
+    return v;
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#multiply(de.jungblut.math.DoubleVector)
+   */
+  @Override
+  public DoubleVector multiplyUnsafe(DoubleVector vector) {
+    // Element-wise product, sized by this vector; no dimension check.
+    DoubleVector v = new DenseDoubleVector(this.getLength());
+    for (int i = 0; i < v.getLength(); i++) {
+      v.set(i, this.get(i) * vector.get(i));
+    }
+    return v;
+  }
+
+  @Override
+  public DoubleVector multiply(DoubleMatrix matrix) {
+    Preconditions.checkArgument(this.vector.length == matrix.getRowCount(),
+        "Dimension mismatch when multiply a vector to a matrix.");
+    return this.multiplyUnsafe(matrix);
+  }
+
+  @Override
+  public DoubleVector multiplyUnsafe(DoubleMatrix matrix) {
+    // Entry i of the result is the dot product of this vector with column i.
+    DoubleVector vec = new DenseDoubleVector(matrix.getColumnCount());
+    for (int i = 0; i < vec.getDimension(); ++i) {
+      vec.set(i, this.multiplyUnsafe(matrix.getColumnVector(i)).sum());
+    }
+    return vec;
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#divide(double)
+   */
+  @Override
+  public DoubleVector divide(double scalar) {
+    DenseDoubleVector v = new DenseDoubleVector(this.getLength());
+    for (int i = 0; i < v.getLength(); i++) {
+      v.set(i, this.get(i) / scalar);
+    }
+    return v;
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#pow(int)
+   */
+  @Override
+  public DoubleVector pow(int x) {
+    DenseDoubleVector v = new DenseDoubleVector(getLength());
+    for (int i = 0; i < v.getLength(); i++) {
+      double value = 0.0d;
+      // it is faster to multiply when we having ^2
+      if (x == 2) {
+        value = vector[i] * vector[i];
+      } else {
+        value = Math.pow(vector[i], x);
+      }
+      v.set(i, value);
+    }
+    return v;
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#sqrt()
+   */
+  @Override
+  public DoubleVector sqrt() {
+    DoubleVector v = new DenseDoubleVector(getLength());
+    for (int i = 0; i < v.getLength(); i++) {
+      v.set(i, Math.sqrt(vector[i]));
+    }
+    return v;
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#sum()
+   */
+  @Override
+  public double sum() {
+    double sum = 0.0d;
+    for (double aVector : vector) {
+      sum += aVector;
+    }
+    return sum;
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#abs()
+   */
+  @Override
+  public DoubleVector abs() {
+    DoubleVector v = new DenseDoubleVector(getLength());
+    for (int i = 0; i < v.getLength(); i++) {
+      v.set(i, Math.abs(vector[i]));
+    }
+    return v;
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#divideFrom(double)
+   */
+  @Override
+  public DoubleVector divideFrom(double scalar) {
+    DoubleVector v = new DenseDoubleVector(this.getLength());
+    for (int i = 0; i < v.getLength(); i++) {
+      // A zero element yields 0 instead of an infinite or NaN quotient.
+      if (this.get(i) != 0.0d) {
+        double result = scalar / this.get(i);
+        v.set(i, result);
+      } else {
+        v.set(i, 0.0d);
+      }
+    }
+    return v;
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#dot(de.jungblut.math.DoubleVector)
+   */
+  @Override
+  public double dotUnsafe(DoubleVector vector) {
+    // The products are accumulated in BigDecimal and converted back to a
+    // double once at the end.
+    BigDecimal dotProduct = BigDecimal.valueOf(0.0d);
+    for (int i = 0; i < getLength(); i++) {
+      dotProduct = dotProduct.add(BigDecimal.valueOf(this.get(i)).multiply(
+          BigDecimal.valueOf(vector.get(i))));
+    }
+    return dotProduct.doubleValue();
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#slice(int)
+   */
+  @Override
+  public DoubleVector slice(int length) {
+    return slice(0, length - 1);
+  }
+
+  @Override
+  public DoubleVector sliceUnsafe(int length) {
+    return sliceUnsafe(0, length - 1);
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#slice(int, int)
+   */
+  @Override
+  public DoubleVector slice(int start, int end) {
+    Preconditions.checkArgument(start >= 0 && start <= end
+        && end < vector.length, "The given from and to is invalid");
+
+    return sliceUnsafe(start, end);
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public DoubleVector sliceUnsafe(int start, int end) {
+    // Both bounds are inclusive, hence the + 1 in the result length.
+    DoubleVector newVec = new DenseDoubleVector(end - start + 1);
+    for (int i = start, j = 0; i <= end; ++i, ++j) {
+      newVec.set(j, vector[i]);
+    }
+
+    return newVec;
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#max()
+   */
+  @Override
+  public double max() {
+    double max = -Double.MAX_VALUE;
+    for (int i = 0; i < getLength(); i++) {
+      double d = vector[i];
+      if (d > max) {
+        max = d;
+      }
+    }
+    return max;
+  }
+
+  /**
+   * @return the index where the maximum resides.
+   */
+  public int maxIndex() {
+    double max = -Double.MAX_VALUE;
+    int maxIndex = 0;
+    for (int i = 0; i < getLength(); i++) {
+      double d = vector[i];
+      if (d > max) {
+        max = d;
+        maxIndex = i;
+      }
+    }
+    return maxIndex;
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#min()
+   */
+  @Override
+  public double min() {
+    double min = Double.MAX_VALUE;
+    for (int i = 0; i < getLength(); i++) {
+      double d = vector[i];
+      if (d < min) {
+        min = d;
+      }
+    }
+    return min;
+  }
+
+  /**
+   * @return the index where the minimum resides.
+   */
+  public int minIndex() {
+    double min = Double.MAX_VALUE;
+    int minIndex = 0;
+    for (int i = 0; i < getLength(); i++) {
+      double d = vector[i];
+      if (d < min) {
+        min = d;
+        minIndex = i;
+      }
+    }
+    return minIndex;
+  }
+
+  /**
+   * @return a new vector which has rinted each element.
+   */
+  public DenseDoubleVector rint() {
+    DenseDoubleVector v = new DenseDoubleVector(getLength());
+    for (int i = 0; i < getLength(); i++) {
+      double d = vector[i];
+      v.set(i, Math.rint(d));
+    }
+    return v;
+  }
+
+  /**
+   * @return a new vector which has rounded each element.
+   */
+  public DenseDoubleVector round() {
+    DenseDoubleVector v = new DenseDoubleVector(getLength());
+    for (int i = 0; i < getLength(); i++) {
+      double d = vector[i];
+      v.set(i, Math.round(d));
+    }
+    return v;
+  }
+
+  /**
+   * @return a new vector which has ceiled each element.
+   */
+  public DenseDoubleVector ceil() {
+    DenseDoubleVector v = new DenseDoubleVector(getLength());
+    for (int i = 0; i < getLength(); i++) {
+      double d = vector[i];
+      v.set(i, Math.ceil(d));
+    }
+    return v;
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#toArray()
+   */
+  @Override
+  public final double[] toArray() {
+    // Returns the backing array itself, not a copy.
+    return vector;
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#isSparse()
+   */
+  @Override
+  public boolean isSparse() {
+    return false;
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#deepCopy()
+   */
+  @Override
+  public DoubleVector deepCopy() {
+    final double[] src = vector;
+    final double[] dest = new double[vector.length];
+    System.arraycopy(src, 0, dest, 0, vector.length);
+    return new DenseDoubleVector(dest);
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#iterateNonZero()
+   */
+  @Override
+  public Iterator iterateNonZero() {
+    return new NonZeroIterator();
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see de.jungblut.math.DoubleVector#iterate()
+   */
+  @Override
+  public Iterator iterate() {
+    return new DefaultIterator();
+  }
+
+  @Override
+  public final String toString() {
+    // Short vectors are printed in full, longer ones only as their size.
+    if (getLength() < 20) {
+      return Arrays.toString(vector);
+    } else {
+      return getLength() + "x1";
+    }
+  }
+
+  @Override
+  public int hashCode() {
+    final int prime = 31;
+    int result = 1;
+    result = prime * result + Arrays.hashCode(vector);
+    return result;
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj)
+      return true;
+    if (obj == null)
+      return false;
+    if (getClass() != obj.getClass())
+      return false;
+    DenseDoubleVector other = (DenseDoubleVector) obj;
+    return Arrays.equals(vector, other.vector);
+  }
+
+  /**
+   * Non-zero iterator for vector elements. The same element instance is
+   * reused for every step and only its index and value are updated.
+   */
+  private final class NonZeroIterator extends
+      AbstractIterator<DoubleVectorElement> {
+
+    private final DoubleVectorElement element = new DoubleVectorElement();
+    private final double[] array;
+    private int currentIndex = 0;
+
+    private NonZeroIterator() {
+      this.array = vector;
+    }
+
+    @Override
+    protected final DoubleVectorElement computeNext() {
+      // Skip zeros, checking the bound before every read so that an empty
+      // array or a run of trailing zeros ends the iteration cleanly.
+      while (currentIndex < array.length && array[currentIndex] == 0.0d) {
+        currentIndex++;
+      }
+      if (currentIndex >= array.length) {
+        return endOfData();
+      }
+      element.setIndex(currentIndex);
+      element.setValue(array[currentIndex]);
+      // Step past the element just reported; without this the same entry
+      // would be returned on every call and the iteration would never end.
+      currentIndex++;
+      return element;
+    }
+  }
+
+  /**
+   * Iterator for all elements. The same element instance is reused for every
+   * step and only its index and value are updated.
+   */
+  private final class DefaultIterator extends
+      AbstractIterator<DoubleVectorElement> {
+
+    private final DoubleVectorElement element = new DoubleVectorElement();
+    private final double[] array;
+    private int currentIndex = 0;
+
+    private DefaultIterator() {
+      this.array = vector;
+    }
+
+    @Override
+    protected final DoubleVectorElement computeNext() {
+      if (currentIndex < array.length) {
+        element.setIndex(currentIndex);
+        element.setValue(array[currentIndex]);
+        currentIndex++;
+        return element;
+      } else {
+        return endOfData();
+      }
+    }
+
+  }
+
+  /**
+   * @return a new vector with dimension num and a default value of 1.
+   */
+  public static DenseDoubleVector ones(int num) {
+    return new DenseDoubleVector(num, 1.0d);
+  }
+
+  /**
+   * @return a new vector filled from index, to index, with a given stepsize.
+   */
+  public static DenseDoubleVector fromUpTo(double from, double to,
+      double stepsize) {
+    DenseDoubleVector v = new DenseDoubleVector(
+        (int) (Math.round(((to - from) / stepsize) + 0.5)));
+
+    for (int i = 0; i < v.getLength(); i++) {
+      v.set(i, from + i * stepsize);
+    }
+    return v;
+  }
+
+  /**
+   * Pairs every element of the vector with its index and sorts the pairs by
+   * element value.
+   *
+   * @param vector the vector whose elements are sorted; it is not modified.
+   * @param scoreComparator the ordering applied to the element values.
+   * @return a list of (value, original index) tuples in sorted order.
+   */
+  public static List<Tuple<Double, Integer>> sort(DoubleVector vector,
+      final Comparator<Double> scoreComparator) {
+    // The type arguments are spelled out in full: each tuple holds the
+    // element value first and its original index second.
+    List<Tuple<Double, Integer>> list = new ArrayList<Tuple<Double, Integer>>(
+        vector.getLength());
+    for (int i = 0; i < vector.getLength(); i++) {
+      list.add(new Tuple<Double, Integer>(vector.get(i), i));
+    }
+    Collections.sort(list, new Comparator<Tuple<Double, Integer>>() {
+      @Override
+      public int compare(Tuple<Double, Integer> o1, Tuple<Double, Integer> o2) {
+        return scoreComparator.compare(o1.getFirst(), o2.getFirst());
+      }
+    });
+    return list;
+  }
+
+  @Override
+  public boolean isNamed() {
+    return false;
+  }
+
+  @Override
+  public String getName() {
+    return null;
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see org.apache.hama.ml.math.DoubleVector#safeAdd(org.apache.hama.ml.math.
+   * DoubleVector)
+   */
+  @Override
+  public DoubleVector add(DoubleVector vector) {
+    Preconditions.checkArgument(this.vector.length == vector.getDimension(),
+        "Dimensions of two vectors do not equal.");
+    return this.addUnsafe(vector);
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see
+   * org.apache.hama.ml.math.DoubleVector#safeSubtract(org.apache.hama.ml.math
+   * .DoubleVector)
+   */
+  @Override
+  public DoubleVector subtract(DoubleVector vector) {
+    Preconditions.checkArgument(this.vector.length == vector.getDimension(),
+        "Dimensions of two vectors do not equal.");
+    return this.subtractUnsafe(vector);
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see
+   * org.apache.hama.ml.math.DoubleVector#safeMultiplay(org.apache.hama.ml.math
+   * .DoubleVector)
+   */
+  @Override
+  public DoubleVector multiply(DoubleVector vector) {
+    Preconditions.checkArgument(this.vector.length == vector.getDimension(),
+        "Dimensions of two vectors do not equal.");
+    return this.multiplyUnsafe(vector);
+  }
+
+  /*
+   * (non-Javadoc)
+   * @see
org.apache.hama.ml.math.DoubleVector#safeDot(org.apache.hama.ml.math. + * DoubleVector) + */ + @Override + public double dot(DoubleVector vector) { + Preconditions.checkArgument(this.vector.length == vector.getDimension(), + "Dimensions of two vectors do not equal."); + return this.dotUnsafe(vector); + } + +} Index: commons/src/main/java/org/apache/hama/commons/math/FunctionFactory.java =================================================================== --- commons/src/main/java/org/apache/hama/commons/math/FunctionFactory.java (revision 0) +++ commons/src/main/java/org/apache/hama/commons/math/FunctionFactory.java (revision 0) @@ -0,0 +1,65 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.commons.math; + +/** + * Factory to create the functions. + * + */ +public class FunctionFactory { + + /** + * Create a double function with specified name. 
+   *
+   * @param functionName the simple class name of the wanted function,
+   *          compared without regard to case.
+   * @return a new instance of the function with that name.
+   */
+  public static DoubleFunction createDoubleFunction(String functionName) {
+    // Candidates are tried in a fixed order; the first name that matches
+    // decides which function is created.
+    final String sigmoidName = Sigmoid.class.getSimpleName();
+    if (functionName.equalsIgnoreCase(sigmoidName)) {
+      return new Sigmoid();
+    }
+
+    final String tanhName = Tanh.class.getSimpleName();
+    if (functionName.equalsIgnoreCase(tanhName)) {
+      return new Tanh();
+    }
+
+    final String identityName = IdentityFunction.class.getSimpleName();
+    if (functionName.equalsIgnoreCase(identityName)) {
+      return new IdentityFunction();
+    }
+
+    final String message = String.format(
+        "No double function with name '%s' exists.", functionName);
+    throw new IllegalArgumentException(message);
+  }
+
+  /**
+   * Creates the two-argument function registered under the given name.
+   *
+   * @param functionName the simple class name of the wanted function,
+   *          compared without regard to case.
+   * @return a new instance of the function with that name.
+   */
+  public static DoubleDoubleFunction createDoubleDoubleFunction(
+      String functionName) {
+    final String squaredErrorName = SquaredError.class.getSimpleName();
+    if (functionName.equalsIgnoreCase(squaredErrorName)) {
+      return new SquaredError();
+    }
+
+    final String crossEntropyName = CrossEntropy.class.getSimpleName();
+    if (functionName.equalsIgnoreCase(crossEntropyName)) {
+      return new CrossEntropy();
+    }
+
+    final String message = String.format(
+        "No double double function with name '%s' exists.", functionName);
+    throw new IllegalArgumentException(message);
+  }
+
+}
Index: commons/src/main/java/org/apache/hama/commons/math/DenseDoubleMatrix.java
===================================================================
--- commons/src/main/java/org/apache/hama/commons/math/DenseDoubleMatrix.java (revision 0)
+++ commons/src/main/java/org/apache/hama/commons/math/DenseDoubleMatrix.java (revision 0)
@@ -0,0 +1,904 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.commons.math; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Random; + +import com.google.common.base.Preconditions; + +/** + * Dense double matrix implementation, internally uses two dimensional double + * arrays. + */ +public final class DenseDoubleMatrix implements DoubleMatrix { + + protected final double[][] matrix; + protected final int numRows; + protected final int numColumns; + + /** + * Creates a new empty matrix from the rows and columns. + * + * @param rows the num of rows. + * @param columns the num of columns. + */ + public DenseDoubleMatrix(int rows, int columns) { + this.numRows = rows; + this.numColumns = columns; + this.matrix = new double[rows][columns]; + } + + /** + * Creates a new empty matrix from the rows and columns filled with the given + * default value. + * + * @param rows the num of rows. + * @param columns the num of columns. + * @param defaultValue the default value. + */ + public DenseDoubleMatrix(int rows, int columns, double defaultValue) { + this.numRows = rows; + this.numColumns = columns; + this.matrix = new double[rows][columns]; + + for (int i = 0; i < numRows; i++) { + Arrays.fill(matrix[i], defaultValue); + } + } + + /** + * Creates a new empty matrix from the rows and columns filled with the given + * random values. + * + * @param rows the num of rows. + * @param columns the num of columns. + * @param rand the random instance to use. 
+ */ + public DenseDoubleMatrix(int rows, int columns, Random rand) { + this.numRows = rows; + this.numColumns = columns; + this.matrix = new double[rows][columns]; + + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numColumns; j++) { + matrix[i][j] = rand.nextDouble(); + } + } + } + + /** + * Simple copy constructor, but does only bend the reference to this instance. + * + * @param otherMatrix the other matrix. + */ + public DenseDoubleMatrix(double[][] otherMatrix) { + this.matrix = otherMatrix; + this.numRows = otherMatrix.length; + if (matrix.length > 0) + this.numColumns = matrix[0].length; + else + this.numColumns = numRows; + } + + /** + * Generates a matrix out of an vector array. it treats the array entries as + * rows and the vector itself contains the values of the columns. + * + * @param vectorArray the array of vectors. + */ + public DenseDoubleMatrix(DoubleVector[] vectorArray) { + this.matrix = new double[vectorArray.length][]; + this.numRows = vectorArray.length; + + for (int i = 0; i < vectorArray.length; i++) { + this.setRowVector(i, vectorArray[i]); + } + + if (matrix.length > 0) + this.numColumns = matrix[0].length; + else + this.numColumns = numRows; + } + + /** + * Sets the first column of this matrix to the given vector. + * + * @param first the new first column of the given vector + */ + public DenseDoubleMatrix(DenseDoubleVector first) { + this(first.getLength(), 1); + setColumn(0, first.toArray()); + } + + /** + * Copies the given double array v into the first row of this matrix, and + * creates this with the number of given rows and columns. + * + * @param v the values to put into the first row. + * @param rows the number of rows. + * @param columns the number of columns. 
+ */ + public DenseDoubleMatrix(double[] v, int rows, int columns) { + this.matrix = new double[rows][columns]; + + for (int i = 0; i < rows; i++) { + System.arraycopy(v, i * columns, this.matrix[i], 0, columns); + } + + int index = 0; + for (int col = 0; col < columns; col++) { + for (int row = 0; row < rows; row++) { + matrix[row][col] = v[index++]; + } + } + + this.numRows = rows; + this.numColumns = columns; + } + + /** + * Creates a new matrix with the given vector into the first column and the + * other matrix to the other columns. + * + * @param first the new first column. + * @param otherMatrix the other matrix to set on from the second column. + */ + public DenseDoubleMatrix(DenseDoubleVector first, DoubleMatrix otherMatrix) { + this(otherMatrix.getRowCount(), otherMatrix.getColumnCount() + 1); + setColumn(0, first.toArray()); + for (int col = 1; col < otherMatrix.getColumnCount() + 1; col++) + setColumnVector(col, otherMatrix.getColumnVector(col - 1)); + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#get(int, int) + */ + @Override + public final double get(int row, int col) { + return this.matrix[row][col]; + } + + /** + * Gets a whole column of the matrix as a double array. + */ + public final double[] getColumn(int col) { + final double[] column = new double[numRows]; + for (int r = 0; r < numRows; r++) { + column[r] = matrix[r][col]; + } + return column; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#getColumnCount() + */ + @Override + public final int getColumnCount() { + return numColumns; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#getColumnVector(int) + */ + @Override + public final DoubleVector getColumnVector(int col) { + return new DenseDoubleVector(getColumn(col)); + } + + /** + * Get the matrix as 2-dimensional double array (first dimension is the row, + * second the column) to faster access the values. 
+ */ + public final double[][] getValues() { + return matrix; + } + + /** + * Get a single row of the matrix as a double array. + */ + public final double[] getRow(int row) { + return matrix[row]; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#getRowCount() + */ + @Override + public final int getRowCount() { + return numRows; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#getRowVector(int) + */ + @Override + public final DoubleVector getRowVector(int row) { + return new DenseDoubleVector(getRow(row)); + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#set(int, int, double) + */ + @Override + public final void set(int row, int col, double value) { + this.matrix[row][col] = value; + } + + /** + * Sets the row to a given double array. This does not copy, rather than just + * bends the references. + */ + public final void setRow(int row, double[] value) { + this.matrix[row] = value; + } + + /** + * Sets the column to a given double array. This does not copy, rather than + * just bends the references. + */ + public final void setColumn(int col, double[] values) { + for (int i = 0; i < getRowCount(); i++) { + this.matrix[i][col] = values[i]; + } + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#setColumnVector(int, + * de.jungblut.math.DoubleVector) + */ + @Override + public void setColumnVector(int col, DoubleVector column) { + this.setColumn(col, column.toArray()); + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#setRowVector(int, + * de.jungblut.math.DoubleVector) + */ + @Override + public void setRowVector(int rowIndex, DoubleVector row) { + this.setRow(rowIndex, row.toArray()); + } + + /** + * Returns the size of the matrix as string (ROWSxCOLUMNS). + */ + public String sizeToString() { + return numRows + "x" + numColumns; + } + + /** + * Splits the last column from this matrix. Usually used to get a prediction + * column from some machine learning problem. 
+ * + * @return a tuple of a new sliced matrix and a vector which was the last + * column. + */ + public final Tuple splitLastColumn() { + DenseDoubleMatrix m = new DenseDoubleMatrix(getRowCount(), + getColumnCount() - 1); + for (int i = 0; i < getRowCount(); i++) { + for (int j = 0; j < getColumnCount() - 1; j++) { + m.set(i, j, get(i, j)); + } + } + DenseDoubleVector v = new DenseDoubleVector(getColumn(getColumnCount() - 1)); + return new Tuple(m, v); + } + + /** + * Creates two matrices out of this by the given percentage. It uses a random + * function to determine which rows should belong to the matrix including the + * given percentage amount of rows. + * + * @param percentage A float value between 0.0f and 1.0f + * @return A tuple which includes two matrices, the first contains the + * percentage of the rows from the original matrix (rows are chosen + * randomly) and the second one contains all other rows. + */ + public final Tuple splitRandomMatrices( + float percentage) { + if (percentage < 0.0f || percentage > 1.0f) { + throw new IllegalArgumentException( + "Percentage must be between 0.0 and 1.0! 
Given " + percentage); + } + + if (percentage == 1.0f) { + return new Tuple(this, null); + } else if (percentage == 0.0f) { + return new Tuple(null, this); + } + + final Random rand = new Random(System.nanoTime()); + int firstMatrixRowsCount = Math.round(percentage * numRows); + + // we first choose needed rows number of items to pick + final HashSet lowerMatrixRowIndices = new HashSet(); + int missingRows = firstMatrixRowsCount; + while (missingRows > 0) { + final int nextIndex = rand.nextInt(numRows); + if (lowerMatrixRowIndices.add(nextIndex)) { + missingRows--; + } + } + + // make to new matrixes + final double[][] firstMatrix = new double[firstMatrixRowsCount][numColumns]; + int firstMatrixIndex = 0; + final double[][] secondMatrix = new double[numRows - firstMatrixRowsCount][numColumns]; + int secondMatrixIndex = 0; + + // then we loop over all items and put split the matrix + for (int r = 0; r < numRows; r++) { + if (lowerMatrixRowIndices.contains(r)) { + firstMatrix[firstMatrixIndex++] = matrix[r]; + } else { + secondMatrix[secondMatrixIndex++] = matrix[r]; + } + } + + return new Tuple( + new DenseDoubleMatrix(firstMatrix), new DenseDoubleMatrix(secondMatrix)); + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#multiply(double) + */ + @Override + public final DenseDoubleMatrix multiply(double scalar) { + DenseDoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns); + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numColumns; j++) { + m.set(i, j, this.matrix[i][j] * scalar); + } + } + return m; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#multiply(de.jungblut.math.DoubleMatrix) + */ + @Override + public final DoubleMatrix multiplyUnsafe(DoubleMatrix other) { + DenseDoubleMatrix matrix = new DenseDoubleMatrix(this.getRowCount(), + other.getColumnCount()); + + final int m = this.numRows; + final int n = this.numColumns; + final int p = other.getColumnCount(); + + for (int j = p; --j >= 0;) { + 
for (int i = m; --i >= 0;) { + double s = 0; + for (int k = n; --k >= 0;) { + s += get(i, k) * other.get(k, j); + } + matrix.set(i, j, s + matrix.get(i, j)); + } + } + + return matrix; + } + + /* + * (non-Javadoc) + * @see + * de.jungblut.math.DoubleMatrix#multiplyElementWise(de.jungblut.math.DoubleMatrix + * ) + */ + @Override + public final DoubleMatrix multiplyElementWiseUnsafe(DoubleMatrix other) { + DenseDoubleMatrix matrix = new DenseDoubleMatrix(this.numRows, + this.numColumns); + + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numColumns; j++) { + matrix.set(i, j, this.get(i, j) * (other.get(i, j))); + } + } + + return matrix; + } + + /* + * (non-Javadoc) + * @see + * de.jungblut.math.DoubleMatrix#multiplyVector(de.jungblut.math.DoubleVector) + */ + @Override + public final DoubleVector multiplyVectorUnsafe(DoubleVector v) { + DoubleVector vector = new DenseDoubleVector(this.getRowCount()); + for (int row = 0; row < numRows; row++) { + double sum = 0.0d; + for (int col = 0; col < numColumns; col++) { + sum += (matrix[row][col] * v.get(col)); + } + vector.set(row, sum); + } + + return vector; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#transpose() + */ + @Override + public DenseDoubleMatrix transpose() { + DenseDoubleMatrix m = new DenseDoubleMatrix(this.numColumns, this.numRows); + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numColumns; j++) { + m.set(j, i, this.matrix[i][j]); + } + } + return m; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#subtractBy(double) + */ + @Override + public DenseDoubleMatrix subtractBy(double amount) { + DenseDoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns); + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numColumns; j++) { + m.set(i, j, amount - this.matrix[i][j]); + } + } + return m; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#subtract(double) + */ + @Override + public DenseDoubleMatrix 
subtract(double amount) { + DenseDoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns); + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numColumns; j++) { + m.set(i, j, this.matrix[i][j] - amount); + } + } + return m; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#subtract(de.jungblut.math.DoubleMatrix) + */ + @Override + public DoubleMatrix subtractUnsafe(DoubleMatrix other) { + DoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns); + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numColumns; j++) { + m.set(i, j, this.matrix[i][j] - other.get(i, j)); + } + } + return m; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#subtract(de.jungblut.math.DoubleVector) + */ + @Override + public DenseDoubleMatrix subtractUnsafe(DoubleVector vec) { + DenseDoubleMatrix cop = new DenseDoubleMatrix(this.getRowCount(), + this.getColumnCount()); + for (int i = 0; i < this.getColumnCount(); i++) { + cop.setColumn(i, getColumnVector(i).subtract(vec.get(i)).toArray()); + } + return cop; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#divide(de.jungblut.math.DoubleVector) + */ + @Override + public DoubleMatrix divideUnsafe(DoubleVector vec) { + DoubleMatrix cop = new DenseDoubleMatrix(this.getRowCount(), + this.getColumnCount()); + for (int i = 0; i < this.getColumnCount(); i++) { + cop.setColumnVector(i, getColumnVector(i).divide(vec.get(i))); + } + return cop; + } + + /** + * {@inheritDoc} + */ + @Override + public DoubleMatrix divide(DoubleVector vec) { + Preconditions.checkArgument(this.getColumnCount() == vec.getDimension(), + "Dimension mismatch."); + return this.divideUnsafe(vec); + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#divide(de.jungblut.math.DoubleMatrix) + */ + @Override + public DoubleMatrix divideUnsafe(DoubleMatrix other) { + DoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns); + for (int i = 0; i < numRows; i++) { 
+ for (int j = 0; j < numColumns; j++) { + m.set(i, j, this.matrix[i][j] / other.get(i, j)); + } + } + return m; + } + + @Override + public DoubleMatrix divide(DoubleMatrix other) { + Preconditions.checkArgument(this.getRowCount() == other.getRowCount() + && this.getColumnCount() == other.getColumnCount()); + return divideUnsafe(other); + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#divide(double) + */ + @Override + public DoubleMatrix divide(double scalar) { + DoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns); + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numColumns; j++) { + m.set(i, j, this.matrix[i][j] / scalar); + } + } + return m; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#add(de.jungblut.math.DoubleMatrix) + */ + @Override + public DoubleMatrix add(DoubleMatrix other) { + DoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns); + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numColumns; j++) { + m.set(i, j, this.matrix[i][j] + other.get(i, j)); + } + } + return m; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#pow(int) + */ + @Override + public DoubleMatrix pow(int x) { + DoubleMatrix m = new DenseDoubleMatrix(this.numRows, this.numColumns); + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numColumns; j++) { + m.set(i, j, Math.pow(matrix[i][j], x)); + } + } + return m; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#max(int) + */ + @Override + public double max(int column) { + double max = Double.MIN_VALUE; + for (int i = 0; i < getRowCount(); i++) { + double d = matrix[i][column]; + if (d > max) { + max = d; + } + } + return max; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#min(int) + */ + @Override + public double min(int column) { + double min = Double.MAX_VALUE; + for (int i = 0; i < getRowCount(); i++) { + double d = matrix[i][column]; + if (d < min) { + min = d; + } + } + 
return min; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#slice(int, int) + */ + @Override + public DoubleMatrix slice(int rows, int cols) { + return slice(0, rows, 0, cols); + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#slice(int, int, int, int) + */ + @Override + public DoubleMatrix slice(int rowOffset, int rowMax, int colOffset, int colMax) { + DenseDoubleMatrix m = new DenseDoubleMatrix(rowMax - rowOffset, colMax + - colOffset); + for (int row = rowOffset; row < rowMax; row++) { + for (int col = colOffset; col < colMax; col++) { + m.set(row - rowOffset, col - colOffset, this.get(row, col)); + } + } + return m; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#isSparse() + */ + @Override + public boolean isSparse() { + return false; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#sum() + */ + @Override + public double sum() { + double x = 0.0d; + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numColumns; j++) { + x += Math.abs(matrix[i][j]); + } + } + return x; + } + + /* + * (non-Javadoc) + * @see de.jungblut.math.DoubleMatrix#columnIndices() + */ + @Override + public int[] columnIndices() { + int[] x = new int[getColumnCount()]; + for (int i = 0; i < getColumnCount(); i++) + x[i] = i; + return x; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + Arrays.hashCode(matrix); + result = prime * result + numColumns; + result = prime * result + numRows; + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + DenseDoubleMatrix other = (DenseDoubleMatrix) obj; + if (!Arrays.deepEquals(matrix, other.matrix)) + return false; + if (numColumns != other.numColumns) + return false; + return numRows == other.numRows; + } + + @Override + public String toString() { + if (numRows < 10) { 
+ StringBuilder sb = new StringBuilder(); + for (int i = 0; i < numRows; i++) { + sb.append(Arrays.toString(matrix[i])); + sb.append('\n'); + } + return sb.toString(); + } else { + return numRows + "x" + numColumns; + } + } + + /** + * Gets the eye matrix (ones on the main diagonal) with a given dimension. + */ + public static DenseDoubleMatrix eye(int dimension) { + DenseDoubleMatrix m = new DenseDoubleMatrix(dimension, dimension); + + for (int i = 0; i < dimension; i++) { + m.set(i, i, 1); + } + + return m; + } + + /** + * Deep copies the given matrix into a new returned one. + */ + public static DenseDoubleMatrix copy(DenseDoubleMatrix matrix) { + final double[][] src = matrix.getValues(); + final double[][] dest = new double[matrix.getRowCount()][matrix + .getColumnCount()]; + + for (int i = 0; i < dest.length; i++) + System.arraycopy(src[i], 0, dest[i], 0, src[i].length); + + return new DenseDoubleMatrix(dest); + } + + /** + * Some strange function I found in octave but I don't know what it was named. + * It does however multiply the elements from the transposed vector and the + * normal vector and sets it into the according indices of a new constructed + * matrix. + */ + public static DenseDoubleMatrix multiplyTransposedVectors( + DoubleVector transposed, DoubleVector normal) { + DenseDoubleMatrix m = new DenseDoubleMatrix(transposed.getLength(), + normal.getLength()); + for (int row = 0; row < transposed.getLength(); row++) { + for (int col = 0; col < normal.getLength(); col++) { + m.set(row, col, transposed.get(row) * normal.get(col)); + } + } + + return m; + } + + /** + * Just a absolute error function. 
+ */ + public static double error(DenseDoubleMatrix a, DenseDoubleMatrix b) { + return a.subtractUnsafe(b).sum(); + } + + @Override + /** + * {@inheritDoc} + */ + public DoubleMatrix applyToElements(DoubleFunction fun) { + for (int r = 0; r < this.numRows; ++r) { + for (int c = 0; c < this.numColumns; ++c) { + this.set(r, c, fun.apply(this.get(r, c))); + } + } + return this; + } + + @Override + /** + * {@inheritDoc} + */ + public DoubleMatrix applyToElements(DoubleMatrix other, + DoubleDoubleFunction fun) { + Preconditions + .checkArgument(this.numRows == other.getRowCount() + && this.numColumns == other.getColumnCount(), + "Cannot apply double double function to matrices with different sizes."); + + for (int r = 0; r < this.numRows; ++r) { + for (int c = 0; c < this.numColumns; ++c) { + this.set(r, c, fun.apply(this.get(r, c), other.get(r, c))); + } + } + + return this; + } + + /* + * (non-Javadoc) + * @see + * org.apache.hama.ml.math.DoubleMatrix#safeMultiply(org.apache.hama.ml.math + * .DoubleMatrix) + */ + @Override + public DoubleMatrix multiply(DoubleMatrix other) { + Preconditions + .checkArgument( + this.numColumns == other.getRowCount(), + String + .format( + "Matrix with size [%d, %d] cannot multiple matrix with size [%d, %d]", + this.numRows, this.numColumns, other.getRowCount(), + other.getColumnCount())); + + return this.multiplyUnsafe(other); + } + + /* + * (non-Javadoc) + * @see + * org.apache.hama.ml.math.DoubleMatrix#safeMultiplyElementWise(org.apache + * .hama.ml.math.DoubleMatrix) + */ + @Override + public DoubleMatrix multiplyElementWise(DoubleMatrix other) { + Preconditions.checkArgument(this.numRows == other.getRowCount() + && this.numColumns == other.getColumnCount(), + "Matrices with different dimensions cannot be multiplied elementwise."); + return this.multiplyElementWiseUnsafe(other); + } + + /* + * (non-Javadoc) + * @see + * org.apache.hama.ml.math.DoubleMatrix#safeMultiplyVector(org.apache.hama + * .ml.math.DoubleVector) + */ + 
@Override + public DoubleVector multiplyVector(DoubleVector v) { + Preconditions.checkArgument(this.numColumns == v.getDimension(), + "Dimension mismatch."); + return this.multiplyVectorUnsafe(v); + } + + /* + * (non-Javadoc) + * @see org.apache.hama.ml.math.DoubleMatrix#subtract(org.apache.hama.ml.math. + * DoubleMatrix) + */ + @Override + public DoubleMatrix subtract(DoubleMatrix other) { + Preconditions.checkArgument(this.numRows == other.getRowCount() + && this.numColumns == other.getColumnCount(), "Dimension mismatch."); + return subtractUnsafe(other); + } + + /* + * (non-Javadoc) + * @see org.apache.hama.ml.math.DoubleMatrix#subtract(org.apache.hama.ml.math. + * DoubleVector) + */ + @Override + public DoubleMatrix subtract(DoubleVector vec) { + Preconditions.checkArgument(this.numColumns == vec.getDimension(), + "Dimension mismatch."); + return null; + } + +} Index: commons/src/main/java/org/apache/hama/commons/math/SquaredError.java =================================================================== --- commons/src/main/java/org/apache/hama/commons/math/SquaredError.java (revision 0) +++ commons/src/main/java/org/apache/hama/commons/math/SquaredError.java (revision 0) @@ -0,0 +1,46 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.commons.math; + +/** + * Square error cost function. + * + *
+ * cost(t, y) = 0.5 * (t - y) ˆ 2
+ * 
+ */ +public class SquaredError extends DoubleDoubleFunction { + + @Override + /** + * {@inheritDoc} + */ + public double apply(double target, double actual) { + double diff = target - actual; + return 0.5 * diff * diff; + } + + @Override + /** + * {@inheritDoc} + */ + public double applyDerivative(double target, double actual) { + return actual - target; + } + +} Index: commons/src/main/java/org/apache/hama/commons/math/DoubleFunction.java =================================================================== --- commons/src/main/java/org/apache/hama/commons/math/DoubleFunction.java (revision 0) +++ commons/src/main/java/org/apache/hama/commons/math/DoubleFunction.java (revision 0) @@ -0,0 +1,43 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.commons.math; + +/** + * A double double function takes two arguments. A vector or matrix can apply + * the double function to each element. + * + */ +public abstract class DoubleFunction extends Function { + + /** + * Apply the function to element. + * + * @param elem The element that the function apply to. + * @return The result after applying the function. 
+ */ + public abstract double apply(double value); + + /** + * Apply the gradient of the function. + * + * @param elem + * @return + */ + public abstract double applyDerivative(double value); + +} Index: commons/src/main/java/org/apache/hama/commons/math/Function.java =================================================================== --- commons/src/main/java/org/apache/hama/commons/math/Function.java (revision 0) +++ commons/src/main/java/org/apache/hama/commons/math/Function.java (revision 0) @@ -0,0 +1,33 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.commons.math; + +/** + * A generic function. + * + */ +public abstract class Function { + /** + * Get the name of the function. + * + * @return The name of the function. 
+ */ + final public String getFunctionName() { + return this.getClass().getSimpleName(); + } +} Index: commons/src/main/java/org/apache/hama/commons/math/Tuple.java =================================================================== --- commons/src/main/java/org/apache/hama/commons/math/Tuple.java (revision 0) +++ commons/src/main/java/org/apache/hama/commons/math/Tuple.java (revision 0) @@ -0,0 +1,85 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.commons.math; + +/** + * Tuple class to hold two generic attributes. This class implements hashcode, + * equals and comparable via the first element. + */ +public final class Tuple implements + Comparable> { + + private final FIRST first; + private final SECOND second; + + public Tuple(FIRST first, SECOND second) { + super(); + this.first = first; + this.second = second; + } + + public final FIRST getFirst() { + return first; + } + + public final SECOND getSecond() { + return second; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((first == null) ? 
0 : first.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + @SuppressWarnings("rawtypes") + Tuple other = (Tuple) obj; + if (first == null) { + if (other.first != null) + return false; + } else if (!first.equals(other.first)) + return false; + return true; + } + + @SuppressWarnings("unchecked") + @Override + public int compareTo(Tuple o) { + if (o.getFirst() instanceof Comparable && getFirst() instanceof Comparable) { + return ((Comparable) getFirst()).compareTo(o.getFirst()); + } else { + return 0; + } + } + + @Override + public String toString() { + return "Tuple [first=" + first + ", second=" + second + "]"; + } + +} Index: commons/src/main/java/org/apache/hama/commons/math/Tanh.java =================================================================== --- commons/src/main/java/org/apache/hama/commons/math/Tanh.java (revision 0) +++ commons/src/main/java/org/apache/hama/commons/math/Tanh.java (revision 0) @@ -0,0 +1,36 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.commons.math; + +/** + * Tanh function. 
+ * + */ +public class Tanh extends DoubleFunction { + + @Override + public double apply(double value) { + return Math.tanh(value); + } + + @Override + public double applyDerivative(double value) { + return 1 - value * value; + } + +} Index: commons/src/main/java/org/apache/hama/commons/math/Sigmoid.java =================================================================== --- commons/src/main/java/org/apache/hama/commons/math/Sigmoid.java (revision 0) +++ commons/src/main/java/org/apache/hama/commons/math/Sigmoid.java (revision 0) @@ -0,0 +1,39 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.commons.math; + +/** + * The Sigmoid function + * + *
+ * f(x) = 1 / (1 + e^{-x})
+ * 
+ */ +public class Sigmoid extends DoubleFunction { + + @Override + public double apply(double value) { + return 1.0 / (1 + Math.exp(-value)); + } + + @Override + public double applyDerivative(double value) { + return value * (1 - value); + } + +} Index: commons/src/main/java/org/apache/hama/commons/math/DoubleVector.java =================================================================== --- commons/src/main/java/org/apache/hama/commons/math/DoubleVector.java (revision 0) +++ commons/src/main/java/org/apache/hama/commons/math/DoubleVector.java (revision 0) @@ -0,0 +1,388 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.commons.math; + +import java.util.Iterator; + +/** + * Vector with doubles. Some of the operations are mutable, unlike the apply and + * math functions, they return a fresh instance every time. + * + */ +public interface DoubleVector { + + /** + * Retrieves the value at given index. + * + * @param index the index. + * @return a double value at the index. + */ + public double get(int index); + + /** + * Get the length of a vector, for sparse instance it is the actual length. + * (not the dimension!) Always a constant time operation. 
+ * + * @return the length of the vector. + */ + public int getLength(); + + /** + * Get the dimension of a vector, for dense instance it is the same like the + * length, for sparse instances it is usually not the same. Always a constant + * time operation. + * + * @return the dimension of the vector. + */ + public int getDimension(); + + /** + * Set a value at the given index. + * + * @param index the index of the vector to set. + * @param value the value at the index of the vector to set. + */ + public void set(int index, double value); + + /** + * Apply a given {@link DoubleVectorFunction} to this vector and return a new + * one. + * + * @param func the function to apply. + * @return a new vector with the applied function. + */ + @Deprecated + public DoubleVector apply(DoubleVectorFunction func); + + /** + * Apply a given {@link DoubleDoubleVectorFunction} to this vector and the + * other given vector. + * + * @param other the other vector. + * @param func the function to apply on this and the other vector. + * @return a new vector with the result of the function of the two vectors. + */ + @Deprecated + public DoubleVector apply(DoubleVector other, DoubleDoubleVectorFunction func); + + /** + * Apply a given {@link DoubleVectorFunction} to this vector and return a new + * one. + * + * @param func the function to apply. + * @return a new vector with the applied function. + */ + public DoubleVector applyToElements(DoubleFunction func); + + /** + * Apply a given {@link DoubleDoubleVectorFunction} to this vector and the + * other given vector. + * + * @param other the other vector. + * @param func the function to apply on this and the other vector. + * @return a new vector with the result of the function of the two vectors. + */ + public DoubleVector applyToElements(DoubleVector other, + DoubleDoubleFunction func); + + /** + * Adds the given {@link DoubleVector} to this vector. + * + * @param vector the other vector. 
+ * @return a new vector with the sum of both vectors at each element index. + */ + public DoubleVector addUnsafe(DoubleVector vector); + + /** + * Validates the input and adds the given {@link DoubleVector} to this vector. + * + * @param vector the other vector. + * @return a new vector with the sum of both vectors at each element index. + */ + public DoubleVector add(DoubleVector vector); + + /** + * Adds the given scalar to this vector. + * + * @param scalar the scalar. + * @return a new vector with the result at each element index. + */ + public DoubleVector add(double scalar); + + /** + * Subtracts this vector by the given {@link DoubleVector}. + * + * @param vector the other vector. + * @return a new vector with the difference of both vectors. + */ + public DoubleVector subtractUnsafe(DoubleVector vector); + + /** + * Validates the input and subtracts this vector by the given + * {@link DoubleVector}. + * + * @param vector the other vector. + * @return a new vector with the difference of both vectors. + */ + public DoubleVector subtract(DoubleVector vector); + + /** + * Subtracts the given scalar from this vector. (vector - scalar). + * + * @param scalar the scalar. + * @return a new vector with the result at each element index. + */ + public DoubleVector subtract(double scalar); + + /** + * Subtracts this vector from the given scalar. (scalar - vector). + * + * @param scalar the scalar. + * @return a new vector with the result at each element index. + */ + public DoubleVector subtractFrom(double scalar); + + /** + * Multiplies this vector by the given scalar. + * + * @param scalar the scalar. + * @return a new vector with the result of the operation. + */ + public DoubleVector multiply(double scalar); + + /** + * Multiplies the given {@link DoubleVector} with this vector. + * + * @param vector the other vector. + * @return a new vector with the result of the operation. 
+ */ + public DoubleVector multiplyUnsafe(DoubleVector vector); + + /** + * Validates the input and multiplies the given {@link DoubleVector} with this + * vector. + * + * @param vector the other vector. + * @return a new vector with the result of the operation. + */ + public DoubleVector multiply(DoubleVector vector); + + /** + * Validates the input and multiplies the given {@link DoubleMatrix} with this + * vector. + * + * @param matrix + * @return + */ + public DoubleVector multiply(DoubleMatrix matrix); + + /** + * Multiplies the given {@link DoubleMatrix} with this vector. + * + * @param matrix + * @return + */ + public DoubleVector multiplyUnsafe(DoubleMatrix matrix); + + /** + * Divides this vector by the given scalar. (= vector/scalar). + * + * @param scalar the given scalar. + * @return a new vector with the result of the operation. + */ + public DoubleVector divide(double scalar); + + /** + * Divides the given scalar by this vector. (= scalar/vector). + * + * @param scalar the given scalar. + * @return a new vector with the result of the operation. + */ + public DoubleVector divideFrom(double scalar); + + /** + * Powers this vector by the given amount. (=vector^x). + * + * @param x the given exponent. + * @return a new vector with the result of the operation. + */ + public DoubleVector pow(int x); + + /** + * Absolutes the vector at each element. + * + * @return a new vector that does not contain negative values anymore. + */ + public DoubleVector abs(); + + /** + * Square-roots each element. + * + * @return a new vector. + */ + public DoubleVector sqrt(); + + /** + * @return the sum of all elements in this vector. + */ + public double sum(); + + /** + * Calculates the dot product between this vector and the given vector. + * + * @param vector the given vector. + * @return the dot product as a double. 
+ */ + public double dotUnsafe(DoubleVector vector); + + /** + * Validates the input and calculates the dot product between this vector and + * the given vector. + * + * @param vector the given vector. + * @return the dot product as a double. + */ + public double dot(DoubleVector vector); + + /** + * Validates the input and slices this vector from index 0 to the given + * length. + * + * @param length must be > 0 and smaller than the dimension of the vector. + * @return a new vector that is only length long. + */ + public DoubleVector slice(int length); + + /** + * Slices this vector from index 0 to the given length. + * + * @param length must be > 0 and smaller than the dimension of the vector. + * @return a new vector that is only length long. + */ + public DoubleVector sliceUnsafe(int length); + + /** + * Validates the input and then slices this vector from start to end, both are + * INCLUSIVE. For example vec = [0, 1, 2, 3, 4, 5], vec.slice(2, 5) = [2, 3, + * 4, 5]. + * + * @param start must be > 0 and smaller than the dimension of the vector + * @param end must be > 0 and smaller than the dimension of the vector. + * This must be greater than the start. + * @return a new vector holding the elements from start to end. + */ + public DoubleVector slice(int start, int end); + + /** + * Slices this vector from start to end, both are INCLUSIVE. For example vec = + * [0, 1, 2, 3, 4, 5], vec.slice(2, 5) = [2, 3, 4, 5]. + * + * @param start must be > 0 and smaller than the dimension of the vector + * @param end must be > 0 and smaller than the dimension of the vector. + * This must be greater than the start. + * @return a new vector holding the elements from start to end. + */ + public DoubleVector sliceUnsafe(int start, int end); + + /** + * @return the maximum element value in this vector. + */ + public double max(); + + /** + * @return the minimum element value in this vector. + */ + public double min(); + + /** + * @return an array representation of this vector. 
+ */ + public double[] toArray(); + + /** + * @return a fresh new copy of this vector, copies all elements to a new + * vector. (Does not reuse references or stuff). + */ + public DoubleVector deepCopy(); + + /** + * @return an iterator that only iterates over non zero elements. + */ + public Iterator iterateNonZero(); + + /** + * @return an iterator that iterates over all elements. + */ + public Iterator iterate(); + + /** + * @return true if this instance is a sparse vector. Smarter and faster than + * instanceof. + */ + public boolean isSparse(); + + /** + * @return true if this instance is a named vector.Smarter and faster than + * instanceof. + */ + public boolean isNamed(); + + /** + * @return If this vector is a named instance, this will return its name. Or + * null if this is not a named instance. + * + */ + public String getName(); + + /** + * Class for iteration of elements, consists of an index and a value at this + * index. Can be reused for performance purposes. + */ + public static final class DoubleVectorElement { + + private int index; + private double value; + + public DoubleVectorElement() { + super(); + } + + public DoubleVectorElement(int index, double value) { + super(); + this.index = index; + this.value = value; + } + + public final int getIndex() { + return index; + } + + public final double getValue() { + return value; + } + + public final void setIndex(int in) { + this.index = in; + } + + public final void setValue(double in) { + this.value = in; + } + } + +} Index: commons/src/main/java/org/apache/hama/commons/math/DoubleDoubleVectorFunction.java =================================================================== --- commons/src/main/java/org/apache/hama/commons/math/DoubleDoubleVectorFunction.java (revision 0) +++ commons/src/main/java/org/apache/hama/commons/math/DoubleDoubleVectorFunction.java (revision 0) @@ -0,0 +1,35 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.commons.math; + +/** + * A function that can be applied to two double vectors via {@link DoubleVector} + * #apply({@link DoubleVector} v, {@link DoubleDoubleVectorFunction} f); + * + * This class will be replaced by {@link DoubleDoubleFunction} + */ +@Deprecated +public interface DoubleDoubleVectorFunction { + + /** + * Calculates the result of the left and right value of two vectors at a given + * index. + */ + public double calculate(int index, double left, double right); + +} Index: commons/src/main/java/org/apache/hama/commons/math/DoubleMatrix.java =================================================================== --- commons/src/main/java/org/apache/hama/commons/math/DoubleMatrix.java (revision 0) +++ commons/src/main/java/org/apache/hama/commons/math/DoubleMatrix.java (revision 0) @@ -0,0 +1,273 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.commons.math; + +/** + * Standard matrix interface for double elements. Every implementation should + * return a fresh new Matrix when operating with other elements. + */ +public interface DoubleMatrix { + + /** + * Not-flagged value for sparse matrices; it defaults to 0.0d. + */ + public static final double NOT_FLAGGED = 0.0d; + + /** + * Get a specific value of the matrix. + * + * @return the double value at the given row and column. + */ + public double get(int row, int col); + + /** + * Returns the number of columns in the matrix. Always a constant time + * operation. + */ + public int getColumnCount(); + + /** + * Get a whole column of the matrix as vector. + */ + public DoubleVector getColumnVector(int col); + + /** + * Returns the number of rows in this matrix. Always a constant time + * operation. + */ + public int getRowCount(); + + /** + * Get a single row of the matrix as a vector. + */ + public DoubleVector getRowVector(int row); + + /** + * Sets the value at the given row and column index. + */ + public void set(int row, int col, double value); + + /** + * Sets a whole column at index col with the given vector. + */ + public void setColumnVector(int col, DoubleVector column); + + /** + * Sets the whole row at index rowIndex with the given vector. + */ + public void setRowVector(int rowIndex, DoubleVector row); + + /** + * Multiplies this matrix (each element) with the given scalar and returns a + * new matrix. 
+ */ + public DoubleMatrix multiply(double scalar); + + /** + * Multiplies this matrix with the given other matrix. + * + * @param other the other matrix. + * @return the product of this matrix and the other matrix. + */ + public DoubleMatrix multiplyUnsafe(DoubleMatrix other); + + /** + * Validates the input and multiplies this matrix with the given other matrix. + * + * @param other the other matrix. + * @return the product of this matrix and the other matrix. + */ + public DoubleMatrix multiply(DoubleMatrix other); + + /** + * Multiplies this matrix per element with a given matrix. + */ + public DoubleMatrix multiplyElementWiseUnsafe(DoubleMatrix other); + + /** + * Validates the input and multiplies this matrix per element with a given + * matrix. + * + * @param other the other matrix + * @return the element-wise product of both matrices. + */ + public DoubleMatrix multiplyElementWise(DoubleMatrix other); + + /** + * Multiplies this matrix with a given vector v. The returning vector contains + * the sum of the rows. + */ + public DoubleVector multiplyVectorUnsafe(DoubleVector v); + + /** + * Validates the input and multiplies this matrix with a given vector v. The + * returning vector contains the sum of the rows. + * + * @param v the vector + * @return the resulting vector. + */ + public DoubleVector multiplyVector(DoubleVector v); + + /** + * Transposes this matrix. + */ + public DoubleMatrix transpose(); + + /** + * Subtracts each element in this matrix from the given amount.
+ * = (amount - matrix value) + */ + public DoubleMatrix subtractBy(double amount); + + /** + * Subtracts each element in this matrix by the given amount.
+ * = (matrix value - amount) + */ + public DoubleMatrix subtract(double amount); + + /** + * Subtracts this matrix by the given other matrix. + */ + public DoubleMatrix subtractUnsafe(DoubleMatrix other); + + /** + * Validates the input and subtracts this matrix by the given other matrix. + * + * @param other + * @return + */ + public DoubleMatrix subtract(DoubleMatrix other); + + /** + * Subtracts each element in a column by the related element in the given + * vector. + */ + public DoubleMatrix subtractUnsafe(DoubleVector vec); + + /** + * Validates and subtracts each element in a column by the related element in + * the given vector. + * + * @param vec + * @return + */ + public DoubleMatrix subtract(DoubleVector vec); + + /** + * Divides each element in a column by the related element in the given + * vector. + */ + public DoubleMatrix divideUnsafe(DoubleVector vec); + + /** + * Validates and divides each element in a column by the related element in + * the given vector. + * + * @param vec + * @return + */ + public DoubleMatrix divide(DoubleVector vec); + + /** + * Divides this matrix by the given other matrix. (Per element division). + */ + public DoubleMatrix divideUnsafe(DoubleMatrix other); + + /** + * Validates and divides this matrix by the given other matrix. (Per element + * division). + * + * @param other + * @return + */ + public DoubleMatrix divide(DoubleMatrix other); + + /** + * Divides each element in this matrix by the given scalar. + */ + public DoubleMatrix divide(double scalar); + + /** + * Adds the elements in the given matrix to the elements in this matrix. + */ + public DoubleMatrix add(DoubleMatrix other); + + /** + * Pows each element by the given argument.
+ * = (matrix element^x) + */ + public DoubleMatrix pow(int x); + + /** + * Returns the maximum value of the given column. + */ + public double max(int column); + + /** + * Returns the minimum value of the given column. + */ + public double min(int column); + + /** + * Sums all elements. + */ + public double sum(); + + /** + * Returns an array of column indices existing in this matrix. + */ + public int[] columnIndices(); + + /** + * Returns true if the underlying implementation is sparse. + */ + public boolean isSparse(); + + /** + * Slices the given matrix from 0-rows and from 0-columns. + */ + public DoubleMatrix slice(int rows, int cols); + + /** + * Slices the given matrix from rowOffset-rowMax and from colOffset-colMax. + */ + public DoubleMatrix slice(int rowOffset, int rowMax, int colOffset, int colMax); + + /** + * Apply a double function f(x) onto each element of the matrix. After + * applying, each element of the current matrix will be changed from x to + * f(x). + * + * @param fun The function. + * @return The matrix itself, supply for chain operation. + */ + public DoubleMatrix applyToElements(DoubleFunction fun); + + /** + * Apply a double double function f(x, y) onto each pair of the current matrix + * elements and given matrix. After applying, each element of the current + * matrix will be changed from x to f(x, y). + * + * @param other The matrix contributing the second argument of the function. + * @param fun The function that takes two arguments. + * @return The matrix itself, supply for chain operation. 
+ */ + public DoubleMatrix applyToElements(DoubleMatrix other, + DoubleDoubleFunction fun); + +} Index: commons/src/main/java/org/apache/hama/commons/math/IdentityFunction.java =================================================================== --- commons/src/main/java/org/apache/hama/commons/math/IdentityFunction.java (revision 0) +++ commons/src/main/java/org/apache/hama/commons/math/IdentityFunction.java (revision 0) @@ -0,0 +1,36 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.commons.math; + +/** + * The identity function f(x) = x. 
+ * + */ +public class IdentityFunction extends DoubleFunction { + + @Override + public double apply(double value) { + return value; + } + + @Override + public double applyDerivative(double value) { + return 1; + } + +} Index: commons/src/main/java/org/apache/hama/commons/math/DoubleVectorFunction.java =================================================================== --- commons/src/main/java/org/apache/hama/commons/math/DoubleVectorFunction.java (revision 0) +++ commons/src/main/java/org/apache/hama/commons/math/DoubleVectorFunction.java (revision 0) @@ -0,0 +1,34 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.commons.math; + +/** + * A function that can be applied to a double vector via {@link DoubleVector} + * #apply({@link DoubleVectorFunction} f); + * + * This class will be replaced by {@link DoubleFunction} + */ +@Deprecated +public interface DoubleVectorFunction { + + /** + * Calculates the result with a given index and value of a vector. 
+ */ + public double calculate(int index, double value); + +} Index: commons/src/main/java/org/apache/hama/commons/util/TextPair.java =================================================================== --- commons/src/main/java/org/apache/hama/commons/util/TextPair.java (revision 0) +++ commons/src/main/java/org/apache/hama/commons/util/TextPair.java (revision 0) @@ -0,0 +1,80 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hama.commons.util; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.Writable; + +/** + * TextPair class for use in BipartiteMatching algorithm. 
+ * + */ +public final class TextPair implements Writable { + + Text first; + Text second; + + public TextPair() { + first = new Text(); + second = new Text(); + } + + public TextPair(Text first, Text second) { + this.first = first; + this.second = second; + } + + public Text getFirst() { + return first; + } + + public void setFirst(Text first) { + this.first = first; + } + + public Text getSecond() { + return second; + } + + public void setSecond(Text second) { + this.second = second; + } + + @Override + public void write(DataOutput out) throws IOException { + first.write(out); + second.write(out); + } + + @Override + public void readFields(DataInput in) throws IOException { + + first.readFields(in); + second.readFields(in); + } + + @Override + public String toString() { + return first + " " + second; + } + +} Index: commons/src/main/java/org/apache/hama/commons/util/KeyValuePair.java =================================================================== --- commons/src/main/java/org/apache/hama/commons/util/KeyValuePair.java (revision 0) +++ commons/src/main/java/org/apache/hama/commons/util/KeyValuePair.java (revision 0) @@ -0,0 +1,59 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hama.commons.util; + +/** + * Mutable class for key values. + */ +public class KeyValuePair { + + private K key; + private V value; + + public KeyValuePair() { + + } + + public KeyValuePair(K key, V value) { + super(); + this.key = key; + this.value = value; + } + + public K getKey() { + return key; + } + + public V getValue() { + return value; + } + + public void setKey(K key) { + this.key = key; + } + + public void setValue(V value) { + this.value = value; + } + + public void clear() { + this.key = null; + this.value = null; + } + +} Index: commons/pom.xml =================================================================== --- commons/pom.xml (revision 0) +++ commons/pom.xml (revision 0) @@ -0,0 +1,51 @@ + + + + + + org.apache.hama + hama-parent + 0.7.0-SNAPSHOT + + + 4.0.0 + org.apache.hama + hama-commons + commons + 0.7.0-SNAPSHOT + jar + + + + com.google.guava + guava + 13.0.1 + + + + + hama-commons-${project.version} + + + + maven-surefire-plugin + + + + + Index: examples/src/test/java/org/apache/hama/examples/FastGraphGenTest.java =================================================================== --- examples/src/test/java/org/apache/hama/examples/FastGraphGenTest.java (revision 1535330) +++ examples/src/test/java/org/apache/hama/examples/FastGraphGenTest.java (working copy) @@ -28,7 +28,7 @@ import org.apache.hadoop.io.SequenceFile; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.Writable; -import org.apache.hama.bsp.TextArrayWritable; +import org.apache.hama.commons.io.writable.TextArrayWritable; import org.apache.hama.examples.util.FastGraphGen; import org.junit.Test; Index: examples/src/test/java/org/apache/hama/examples/SpMVTest.java =================================================================== --- examples/src/test/java/org/apache/hama/examples/SpMVTest.java (revision 1535330) +++ examples/src/test/java/org/apache/hama/examples/SpMVTest.java (working copy) @@ -32,8 +32,8 @@ import 
org.apache.hadoop.io.SequenceFile; import org.apache.hadoop.io.Writable; import org.apache.hama.HamaConfiguration; -import org.apache.hama.examples.util.DenseVectorWritable; -import org.apache.hama.examples.util.SparseVectorWritable; +import org.apache.hama.commons.io.writable.DenseVectorWritable; +import org.apache.hama.commons.io.writable.SparseVectorWritable; import org.junit.Before; import org.junit.Test; Index: examples/src/test/java/org/apache/hama/examples/SymmetricMatrixGenTest.java =================================================================== --- examples/src/test/java/org/apache/hama/examples/SymmetricMatrixGenTest.java (revision 1535330) +++ examples/src/test/java/org/apache/hama/examples/SymmetricMatrixGenTest.java (working copy) @@ -26,7 +26,7 @@ import org.apache.hadoop.io.SequenceFile; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.Writable; -import org.apache.hama.bsp.TextArrayWritable; +import org.apache.hama.commons.io.writable.TextArrayWritable; import org.apache.hama.examples.util.SymmetricMatrixGen; import org.junit.Test; Index: examples/src/test/java/org/apache/hama/examples/BipartiteMatchingTest.java =================================================================== --- examples/src/test/java/org/apache/hama/examples/BipartiteMatchingTest.java (revision 1535330) +++ examples/src/test/java/org/apache/hama/examples/BipartiteMatchingTest.java (working copy) @@ -36,7 +36,7 @@ import org.apache.hadoop.io.Text; import org.apache.hama.HamaConfiguration; import org.apache.hama.bsp.Partitioner; -import org.apache.hama.examples.util.TextPair; +import org.apache.hama.commons.util.TextPair; import org.apache.hama.graph.GraphJob; import org.junit.Test; Index: examples/src/test/java/org/apache/hama/examples/NeuralNetworkTest.java =================================================================== --- examples/src/test/java/org/apache/hama/examples/NeuralNetworkTest.java (revision 1535330) +++ 
examples/src/test/java/org/apache/hama/examples/NeuralNetworkTest.java (working copy) @@ -31,8 +31,8 @@ import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.SequenceFile; import org.apache.hama.HamaConfiguration; -import org.apache.hama.ml.math.DenseDoubleVector; -import org.apache.hama.ml.writable.VectorWritable; +import org.apache.hama.commons.io.writable.VectorWritable; +import org.apache.hama.commons.math.DenseDoubleVector; /** * Test the functionality of NeuralNetwork Example. Index: examples/src/main/java/org/apache/hama/examples/Kmeans.java =================================================================== --- examples/src/main/java/org/apache/hama/examples/Kmeans.java (revision 1535330) +++ examples/src/main/java/org/apache/hama/examples/Kmeans.java (working copy) @@ -22,8 +22,8 @@ import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.NullWritable; import org.apache.hama.bsp.BSPJob; +import org.apache.hama.commons.io.writable.VectorWritable; import org.apache.hama.ml.kmeans.KMeansBSP; -import org.apache.hama.ml.writable.VectorWritable; /** * Uses the {@link KMeansBSP} class to run a Kmeans Clustering with BSP. 
You can Index: examples/src/main/java/org/apache/hama/examples/GradientDescentExample.java =================================================================== --- examples/src/main/java/org/apache/hama/examples/GradientDescentExample.java (revision 1535330) +++ examples/src/main/java/org/apache/hama/examples/GradientDescentExample.java (working copy) @@ -29,11 +29,11 @@ import org.apache.hama.bsp.BSPJob; import org.apache.hama.bsp.FileOutputFormat; import org.apache.hama.bsp.TextOutputFormat; +import org.apache.hama.commons.io.writable.VectorWritable; import org.apache.hama.ml.regression.GradientDescentBSP; import org.apache.hama.ml.regression.LogisticRegressionModel; import org.apache.hama.ml.regression.RegressionModel; import org.apache.hama.ml.regression.VectorDoubleFileInputFormat; -import org.apache.hama.ml.writable.VectorWritable; /** * A {@link GradientDescentBSP} job example Index: examples/src/main/java/org/apache/hama/examples/SpMV.java =================================================================== --- examples/src/main/java/org/apache/hama/examples/SpMV.java (revision 1535330) +++ examples/src/main/java/org/apache/hama/examples/SpMV.java (working copy) @@ -42,9 +42,9 @@ import org.apache.hama.bsp.SequenceFileInputFormat; import org.apache.hama.bsp.SequenceFileOutputFormat; import org.apache.hama.bsp.sync.SyncException; -import org.apache.hama.examples.util.DenseVectorWritable; -import org.apache.hama.examples.util.SparseVectorWritable; -import org.apache.hama.util.KeyValuePair; +import org.apache.hama.commons.io.writable.DenseVectorWritable; +import org.apache.hama.commons.io.writable.SparseVectorWritable; +import org.apache.hama.commons.util.KeyValuePair; /** * Sparse matrix vector multiplication. Currently it uses row-wise access. 
Index: examples/src/main/java/org/apache/hama/examples/PageRank.java =================================================================== --- examples/src/main/java/org/apache/hama/examples/PageRank.java (revision 1535330) +++ examples/src/main/java/org/apache/hama/examples/PageRank.java (working copy) @@ -28,8 +28,8 @@ import org.apache.hama.HamaConfiguration; import org.apache.hama.bsp.HashPartitioner; import org.apache.hama.bsp.SequenceFileInputFormat; -import org.apache.hama.bsp.TextArrayWritable; import org.apache.hama.bsp.TextOutputFormat; +import org.apache.hama.commons.io.writable.TextArrayWritable; import org.apache.hama.graph.AverageAggregator; import org.apache.hama.graph.Edge; import org.apache.hama.graph.GraphJob; Index: examples/src/main/java/org/apache/hama/examples/InlinkCount.java =================================================================== --- examples/src/main/java/org/apache/hama/examples/InlinkCount.java (revision 1535330) +++ examples/src/main/java/org/apache/hama/examples/InlinkCount.java (working copy) @@ -26,8 +26,8 @@ import org.apache.hama.HamaConfiguration; import org.apache.hama.bsp.HashPartitioner; import org.apache.hama.bsp.SequenceFileOutputFormat; -import org.apache.hama.bsp.TextArrayWritable; import org.apache.hama.bsp.TextInputFormat; +import org.apache.hama.commons.io.writable.TextArrayWritable; import org.apache.hama.graph.GraphJob; import org.apache.hama.graph.Vertex; Index: examples/src/main/java/org/apache/hama/examples/util/FastGraphGen.java =================================================================== --- examples/src/main/java/org/apache/hama/examples/util/FastGraphGen.java (revision 1535330) +++ examples/src/main/java/org/apache/hama/examples/util/FastGraphGen.java (working copy) @@ -34,8 +34,8 @@ import org.apache.hama.bsp.FileOutputFormat; import org.apache.hama.bsp.NullInputFormat; import org.apache.hama.bsp.SequenceFileOutputFormat; -import org.apache.hama.bsp.TextArrayWritable; import 
org.apache.hama.bsp.sync.SyncException; +import org.apache.hama.commons.io.writable.TextArrayWritable; import com.google.common.collect.Sets; Index: examples/src/main/java/org/apache/hama/examples/util/SparseVectorWritable.java =================================================================== --- examples/src/main/java/org/apache/hama/examples/util/SparseVectorWritable.java (revision 1535330) +++ examples/src/main/java/org/apache/hama/examples/util/SparseVectorWritable.java (working copy) @@ -1,105 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.examples.util; - -import java.io.DataInput; -import java.io.DataOutput; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -import org.apache.hadoop.io.Writable; - -/** - * This class represents sparse vector. It will give improvement in memory - * consumption in case of vectors which sparsity is close to zero. Can be used - * in SpMV for representing input matrix rows efficiently. Internally represents - * values as list of indeces and list of values. 
- */ -public class SparseVectorWritable implements Writable { - - private Integer size; - private List indeces; - private List values; - - public SparseVectorWritable() { - indeces = new ArrayList(); - values = new ArrayList(); - } - - public void clear() { - indeces = new ArrayList(); - values = new ArrayList(); - } - - public void addCell(int index, double value) { - indeces.add(index); - values.add(value); - } - - public void setSize(int size) { - this.size = size; - } - - public int getSize() { - if (size != null) - return size; - return indeces.size(); - } - - public List getIndeces() { - return indeces; - } - - public List getValues() { - return values; - } - - @Override - public void readFields(DataInput in) throws IOException { - clear(); - int size = in.readInt(); - int len = in.readInt(); - setSize(size); - for (int i = 0; i < len; i++) { - int index = in.readInt(); - double value = in.readDouble(); - this.addCell(index, value); - } - } - - @Override - public void write(DataOutput out) throws IOException { - out.writeInt(getSize()); - out.writeInt(indeces.size()); - for (int i = 0; i < indeces.size(); i++) { - out.writeInt(indeces.get(i)); - out.writeDouble(values.get(i)); - } - } - - @Override - public String toString() { - StringBuilder st = new StringBuilder(); - st.append(" " + getSize() + " " + indeces.size()); - for (int i = 0; i < indeces.size(); i++) - st.append(" " + indeces.get(i) + " " + values.get(i)); - return st.toString(); - } - -} Index: examples/src/main/java/org/apache/hama/examples/util/DenseVectorWritable.java =================================================================== --- examples/src/main/java/org/apache/hama/examples/util/DenseVectorWritable.java (revision 1535330) +++ examples/src/main/java/org/apache/hama/examples/util/DenseVectorWritable.java (working copy) @@ -1,87 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.examples.util; - -import java.io.DataInput; -import java.io.DataOutput; -import java.io.IOException; - -import org.apache.hadoop.io.Writable; - -/** - * This class represents dense vector. It will improve memory consumption up to - * two times in comparison to SparseVectorWritable in case of vectors which - * sparsity is close to 1. Internally represents vector values as array. Can be - * used in SpMV for representation of input and output vector. 
- */ -public class DenseVectorWritable implements Writable { - - private double values[]; - - public DenseVectorWritable() { - values = new double[0]; - } - - public int getSize() { - return values.length; - } - - public void setSize(int size) { - values = new double[size]; - } - - public double get(int index) { - return values[index]; - } - - public void addCell(int index, double value) { - values[index] = value; - } - - @Override - public void readFields(DataInput in) throws IOException { - int size = in.readInt(); - int len = in.readInt(); - setSize(size); - for (int i = 0; i < len; i++) { - int index = in.readInt(); - double value = in.readDouble(); - values[index] = value; - } - } - - @Override - public void write(DataOutput out) throws IOException { - out.writeInt(getSize()); - out.writeInt(getSize()); - for (int i = 0; i < getSize(); i++) { - out.writeInt(i); - out.writeDouble(values[i]); - } - } - - @Override - public String toString() { - StringBuilder st = new StringBuilder(); - st.append(" " + getSize() + " " + getSize()); - for (int i = 0; i < getSize(); i++) - st.append(" " + i + " " + values[i]); - return st.toString(); - } - -} Index: examples/src/main/java/org/apache/hama/examples/util/SymmetricMatrixGen.java =================================================================== --- examples/src/main/java/org/apache/hama/examples/util/SymmetricMatrixGen.java (revision 1535330) +++ examples/src/main/java/org/apache/hama/examples/util/SymmetricMatrixGen.java (working copy) @@ -37,8 +37,8 @@ import org.apache.hama.bsp.FileOutputFormat; import org.apache.hama.bsp.NullInputFormat; import org.apache.hama.bsp.SequenceFileOutputFormat; -import org.apache.hama.bsp.TextArrayWritable; import org.apache.hama.bsp.sync.SyncException; +import org.apache.hama.commons.io.writable.TextArrayWritable; import org.apache.hama.examples.CombineExample; public class SymmetricMatrixGen { Index: examples/src/main/java/org/apache/hama/examples/util/TextPair.java 
=================================================================== --- examples/src/main/java/org/apache/hama/examples/util/TextPair.java (revision 1535330) +++ examples/src/main/java/org/apache/hama/examples/util/TextPair.java (working copy) @@ -1,80 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hama.examples.util; - -import java.io.DataInput; -import java.io.DataOutput; -import java.io.IOException; - -import org.apache.hadoop.io.Text; -import org.apache.hadoop.io.Writable; - -/** - * TextPair class for use in BipartiteMatching algorithm. 
- * - */ -public final class TextPair implements Writable { - - Text first; - Text second; - - public TextPair() { - first = new Text(); - second = new Text(); - } - - public TextPair(Text first, Text second) { - this.first = first; - this.second = second; - } - - public Text getFirst() { - return first; - } - - public void setFirst(Text first) { - this.first = first; - } - - public Text getSecond() { - return second; - } - - public void setSecond(Text second) { - this.second = second; - } - - @Override - public void write(DataOutput out) throws IOException { - first.write(out); - second.write(out); - } - - @Override - public void readFields(DataInput in) throws IOException { - - first.readFields(in); - second.readFields(in); - } - - @Override - public String toString() { - return first + " " + second; - } - -} Index: examples/src/main/java/org/apache/hama/examples/BipartiteMatching.java =================================================================== --- examples/src/main/java/org/apache/hama/examples/BipartiteMatching.java (revision 1535330) +++ examples/src/main/java/org/apache/hama/examples/BipartiteMatching.java (working copy) @@ -33,7 +33,7 @@ import org.apache.hama.bsp.HashPartitioner; import org.apache.hama.bsp.TextInputFormat; import org.apache.hama.bsp.TextOutputFormat; -import org.apache.hama.examples.util.TextPair; +import org.apache.hama.commons.util.TextPair; import org.apache.hama.graph.Edge; import org.apache.hama.graph.GraphJob; import org.apache.hama.graph.Vertex; Index: examples/src/main/java/org/apache/hama/examples/NeuralNetwork.java =================================================================== --- examples/src/main/java/org/apache/hama/examples/NeuralNetwork.java (revision 1535330) +++ examples/src/main/java/org/apache/hama/examples/NeuralNetwork.java (working copy) @@ -28,10 +28,10 @@ import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; +import 
org.apache.hama.commons.math.DenseDoubleVector; +import org.apache.hama.commons.math.DoubleVector; +import org.apache.hama.commons.math.FunctionFactory; import org.apache.hama.ml.ann.SmallLayeredNeuralNetwork; -import org.apache.hama.ml.math.DenseDoubleVector; -import org.apache.hama.ml.math.DoubleVector; -import org.apache.hama.ml.math.FunctionFactory; /** * The example of using {@link SmallLayeredNeuralNetwork}, including the Index: examples/pom.xml =================================================================== --- examples/pom.xml (revision 1535330) +++ examples/pom.xml (working copy) @@ -33,6 +33,11 @@ org.apache.hama + hama-commons + ${project.version} + + + org.apache.hama hama-core ${project.version} Index: src/assemble/bin.xml =================================================================== --- src/assemble/bin.xml (revision 1535330) +++ src/assemble/bin.xml (working copy) @@ -27,6 +27,18 @@ + ../commons/target + + hama-*.jar + + + *sources.jar + *tests.jar + *javadoc.jar + + ../hama-${project.version}/ + + ../core/target hama-*.jar