diff --git a/src/main/java/org/apache/hadoop/hbase/client/coprocessor/Exec.java b/src/main/java/org/apache/hadoop/hbase/client/coprocessor/Exec.java index c127ea3..428dba6 100644 --- a/src/main/java/org/apache/hadoop/hbase/client/coprocessor/Exec.java +++ b/src/main/java/org/apache/hadoop/hbase/client/coprocessor/Exec.java @@ -26,6 +26,7 @@ import org.apache.hadoop.hbase.io.HbaseObjectWritable; import org.apache.hadoop.hbase.ipc.CoprocessorProtocol; import org.apache.hadoop.hbase.ipc.Invocation; import org.apache.hadoop.hbase.util.Bytes; +import org.apache.hadoop.hbase.util.Classes; import java.io.DataInput; import java.io.DataOutput; @@ -83,14 +84,37 @@ public class Exec extends Invocation implements Row { @Override public void write(DataOutput out) throws IOException { - super.write(out); + // fields for Invocation + out.writeUTF(this.methodName); + out.writeInt(parameterClasses.length); + for (int i = 0; i < parameterClasses.length; i++) { + HbaseObjectWritable.writeObject(out, parameters[i], parameters[i].getClass(), + conf); + out.writeUTF(parameterClasses[i].getName()); + } + // fields for Exec Bytes.writeByteArray(out, referenceRow); out.writeUTF(protocol.getName()); } @Override public void readFields(DataInput in) throws IOException { - super.readFields(in); + // fields for Invocation + methodName = in.readUTF(); + parameters = new Object[in.readInt()]; + parameterClasses = new Class[parameters.length]; + HbaseObjectWritable objectWritable = new HbaseObjectWritable(); + for (int i = 0; i < parameters.length; i++) { + parameters[i] = HbaseObjectWritable.readObject(in, objectWritable, + this.conf); + String parameterClassName = in.readUTF(); + try { + parameterClasses[i] = Classes.extendedForName(parameterClassName); + } catch (ClassNotFoundException e) { + throw new IOException("Couldn't find class: " + parameterClassName); + } + } + // fields for Exec referenceRow = Bytes.readByteArray(in); String protocolName = in.readUTF(); try { diff --git a/src/main/java/org/apache/hadoop/hbase/client/coprocessor/ExecResult.java b/src/main/java/org/apache/hadoop/hbase/client/coprocessor/ExecResult.java index be46cd2..fe422c8 100644 --- a/src/main/java/org/apache/hadoop/hbase/client/coprocessor/ExecResult.java +++ b/src/main/java/org/apache/hadoop/hbase/client/coprocessor/ExecResult.java @@ -21,11 +21,13 @@ package org.apache.hadoop.hbase.client.coprocessor; import org.apache.hadoop.hbase.io.HbaseObjectWritable; import org.apache.hadoop.hbase.util.Bytes; +import org.apache.hadoop.hbase.util.Classes; import org.apache.hadoop.io.Writable; import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; +import java.io.Serializable; /** * Represents the return value from a @@ -70,12 +72,25 @@ public class ExecResult implements Writable { public void write(DataOutput out) throws IOException { Bytes.writeByteArray(out, regionName); HbaseObjectWritable.writeObject(out, value, - (valueType != null ? valueType : Writable.class), null); + value.getClass(), null); + Class alternativeSerializationClass; + if(value instanceof Writable){ + alternativeSerializationClass = Writable.class; + } else { + alternativeSerializationClass = Serializable.class; + } + out.writeUTF((valueType != null ? valueType : alternativeSerializationClass).getName()); } @Override public void readFields(DataInput in) throws IOException { regionName = Bytes.readByteArray(in); value = HbaseObjectWritable.readObject(in, null); + String className = in.readUTF(); + try { + valueType = Classes.extendedForName(className); + } catch (ClassNotFoundException e) { + throw new IOException("Unable to find class of type: " + className ); + } } } diff --git a/src/main/java/org/apache/hadoop/hbase/ipc/Invocation.java b/src/main/java/org/apache/hadoop/hbase/ipc/Invocation.java index 9609652..e60f970 100644 --- a/src/main/java/org/apache/hadoop/hbase/ipc/Invocation.java +++ b/src/main/java/org/apache/hadoop/hbase/ipc/Invocation.java @@ -31,11 +31,11 @@ import java.lang.reflect.Method; /** A method invocation, including the method name and its parameters.*/ public class Invocation implements Writable, Configurable { - private String methodName; + protected String methodName; @SuppressWarnings("unchecked") - private Class[] parameterClasses; - private Object[] parameters; - private Configuration conf; + protected Class[] parameterClasses; + protected Object[] parameters; + protected Configuration conf; public Invocation() {} diff --git a/src/main/java/org/apache/hadoop/hbase/util/Classes.java b/src/main/java/org/apache/hadoop/hbase/util/Classes.java new file mode 100644 index 0000000..9313da8 --- /dev/null +++ b/src/main/java/org/apache/hadoop/hbase/util/Classes.java @@ -0,0 +1,44 @@ +package org.apache.hadoop.hbase.util; + +/** + * Utilities for class manipulation. + */ +public class Classes { + + /** + * Equivalent of {@link Class#forName(String)} which also returns classes for + * primitives like boolean, etc. + * + * @param className + * The name of the class to retrieve. Can be either a normal class or + * a primitive class. + * @return The class specified by className + * @throws ClassNotFoundException + * If the requested class can not be found. + */ + public static Class extendedForName(String className) + throws ClassNotFoundException { + Class valueType; + if (className.equals("boolean")) { + valueType = boolean.class; + } else if (className.equals("byte")) { + valueType = byte.class; + } else if (className.equals("short")) { + valueType = short.class; + } else if (className.equals("int")) { + valueType = int.class; + } else if (className.equals("long")) { + valueType = long.class; + } else if (className.equals("float")) { + valueType = float.class; + } else if (className.equals("double")) { + valueType = double.class; + } else if (className.equals("char")) { + valueType = char.class; + } else { + valueType = Class.forName(className); + } + return valueType; + } + +} diff --git a/src/test/java/org/apache/hadoop/hbase/coprocessor/GenericEndpoint.java b/src/test/java/org/apache/hadoop/hbase/coprocessor/GenericEndpoint.java new file mode 100644 index 0000000..5e835f0 --- /dev/null +++ b/src/test/java/org/apache/hadoop/hbase/coprocessor/GenericEndpoint.java @@ -0,0 +1,11 @@ +package org.apache.hadoop.hbase.coprocessor; + +public class GenericEndpoint extends BaseEndpointCoprocessor implements + GenericProtocol { + + @Override + public T doWork(T genericObject) { + return genericObject; + } + +} diff --git a/src/test/java/org/apache/hadoop/hbase/coprocessor/GenericProtocol.java b/src/test/java/org/apache/hadoop/hbase/coprocessor/GenericProtocol.java new file mode 100644 index 0000000..ddb30d4 --- /dev/null +++ b/src/test/java/org/apache/hadoop/hbase/coprocessor/GenericProtocol.java @@ -0,0 +1,17 @@ +package org.apache.hadoop.hbase.coprocessor; + +import org.apache.hadoop.hbase.ipc.CoprocessorProtocol; + +public interface GenericProtocol extends CoprocessorProtocol { + + /** + * Simple interface to allow the passing of a generic parameter to see if the + * RPC framework can accommodate generics. + * + * @param + * @param genericObject + * @return + */ + public T doWork(T genericObject); + +} diff --git a/src/test/java/org/apache/hadoop/hbase/coprocessor/TestCoprocessorEndpoint.java b/src/test/java/org/apache/hadoop/hbase/coprocessor/TestCoprocessorEndpoint.java index 75f76e8..334f5a5 100644 --- a/src/test/java/org/apache/hadoop/hbase/coprocessor/TestCoprocessorEndpoint.java +++ b/src/test/java/org/apache/hadoop/hbase/coprocessor/TestCoprocessorEndpoint.java @@ -19,17 +19,25 @@ */ package org.apache.hadoop.hbase.coprocessor; -import org.apache.hadoop.hbase.*; -import org.apache.hadoop.hbase.client.*; -import org.apache.hadoop.hbase.client.coprocessor.*; -import org.apache.hadoop.hbase.util.Bytes; -import org.junit.*; -import org.apache.hadoop.conf.Configuration; - -import static org.junit.Assert.*; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; -import java.util.Map; import java.io.IOException; +import java.util.Map; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hbase.HBaseTestingUtility; +import org.apache.hadoop.hbase.HConstants; +import org.apache.hadoop.hbase.MiniHBaseCluster; +import org.apache.hadoop.hbase.client.HTable; +import org.apache.hadoop.hbase.client.Put; +import org.apache.hadoop.hbase.client.Scan; +import org.apache.hadoop.hbase.client.coprocessor.Batch; +import org.apache.hadoop.hbase.util.Bytes; +import org.apache.hadoop.io.Text; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; /** * TestEndpoint: test cases to verify coprocessor Endpoint @@ -39,12 +47,12 @@ public class TestCoprocessorEndpoint { private static final byte[] TEST_TABLE = Bytes.toBytes("TestTable"); private static final byte[] TEST_FAMILY = Bytes.toBytes("TestFamily"); private static final byte[] TEST_QUALIFIER = Bytes.toBytes("TestQualifier"); - private static byte [] ROW = Bytes.toBytes("testRow"); + private static byte[] ROW = Bytes.toBytes("testRow"); private static final int ROWSIZE = 20; private static final int rowSeperator1 = 5; private static final int rowSeperator2 = 12; - private static byte [][] ROWS = makeN(ROW, ROWSIZE); + private static byte[][] ROWS = makeN(ROW, ROWSIZE); private static HBaseTestingUtility util = new HBaseTestingUtility(); private static MiniHBaseCluster cluster = null; @@ -53,18 +61,19 @@ public class TestCoprocessorEndpoint { public static void setupBeforeClass() throws Exception { // set configure to indicate which cp should be loaded Configuration conf = util.getConfiguration(); - conf.set(CoprocessorHost.REGION_COPROCESSOR_CONF_KEY, - "org.apache.hadoop.hbase.coprocessor.ColumnAggregationEndpoint"); + conf.setStrings(CoprocessorHost.REGION_COPROCESSOR_CONF_KEY, + "org.apache.hadoop.hbase.coprocessor.ColumnAggregationEndpoint", + "org.apache.hadoop.hbase.coprocessor.GenericEndpoint"); util.startMiniCluster(2); cluster = util.getMiniHBaseCluster(); HTable table = util.createTable(TEST_TABLE, TEST_FAMILY); util.createMultiRegions(util.getConfiguration(), table, TEST_FAMILY, - new byte[][]{ HConstants.EMPTY_BYTE_ARRAY, ROWS[rowSeperator1], - ROWS[rowSeperator2]}); + new byte[][] { HConstants.EMPTY_BYTE_ARRAY, + ROWS[rowSeperator1], ROWS[rowSeperator2] }); - for(int i = 0; i < ROWSIZE; i++) { + for (int i = 0; i < ROWSIZE; i++) { Put put = new Put(ROWS[i]); put.add(TEST_FAMILY, TEST_QUALIFIER, Bytes.toBytes(i)); table.put(put); @@ -80,27 +89,56 @@ public class TestCoprocessorEndpoint { } @Test + public void testGeneric() throws Throwable { + HTable table = new HTable(util.getConfiguration(), TEST_TABLE); + GenericProtocol protocol = table.coprocessorProxy(GenericProtocol.class, + Bytes.toBytes("testRow")); + String workResult1 = protocol.doWork("foo"); + assertEquals("foo", workResult1); + byte[] workResult2 = protocol.doWork(new byte[]{1}); + assertArrayEquals(new byte[]{1}, workResult2); + byte workResult3 = protocol.doWork((byte)1); + assertEquals((byte)1, workResult3); + char workResult4 = protocol.doWork('c'); + assertEquals('c', workResult4); + boolean workResult5 = protocol.doWork(true); + assertEquals(true, workResult5); + short workResult6 = protocol.doWork((short)1); + assertEquals((short)1, workResult6); + int workResult7 = protocol.doWork(5); + assertEquals(5, workResult7); + long workResult8 = protocol.doWork(5l); + assertEquals(5l, workResult8); + double workResult9 = protocol.doWork(6d); + assertEquals(6d, workResult9, 0.01); + float workResult10 = protocol.doWork(6f); + assertEquals(6f, workResult10, 0.01); + Text workResult11 = protocol.doWork(new Text("foo")); + assertEquals(new Text("foo"), workResult11); + } + + @Test public void testAggregation() throws Throwable { HTable table = new HTable(util.getConfiguration(), TEST_TABLE); Scan scan; Map results; // scan: for all regions - results = table.coprocessorExec(ColumnAggregationProtocol.class, - ROWS[rowSeperator1 - 1], - ROWS[rowSeperator2 + 1], - new Batch.Call() { - public Long call(ColumnAggregationProtocol instance) - throws IOException{ - return instance.sum(TEST_FAMILY, TEST_QUALIFIER); - } - }); + results = table + .coprocessorExec(ColumnAggregationProtocol.class, + ROWS[rowSeperator1 - 1], ROWS[rowSeperator2 + 1], + new Batch.Call() { + public Long call(ColumnAggregationProtocol instance) + throws IOException { + return instance.sum(TEST_FAMILY, TEST_QUALIFIER); + } + }); int sumResult = 0; int expectedResult = 0; for (Map.Entry e : results.entrySet()) { sumResult += e.getValue(); } - for(int i = 0;i < ROWSIZE; i++) { + for (int i = 0; i < ROWSIZE; i++) { expectedResult += i; } assertEquals("Invalid result", sumResult, expectedResult); @@ -108,29 +146,29 @@ public class TestCoprocessorEndpoint { results.clear(); // scan: for region 2 and region 3 - results = table.coprocessorExec(ColumnAggregationProtocol.class, - ROWS[rowSeperator1 + 1], - ROWS[rowSeperator2 + 1], - new Batch.Call() { - public Long call(ColumnAggregationProtocol instance) - throws IOException{ - return instance.sum(TEST_FAMILY, TEST_QUALIFIER); - } - }); + results = table + .coprocessorExec(ColumnAggregationProtocol.class, + ROWS[rowSeperator1 + 1], ROWS[rowSeperator2 + 1], + new Batch.Call() { + public Long call(ColumnAggregationProtocol instance) + throws IOException { + return instance.sum(TEST_FAMILY, TEST_QUALIFIER); + } + }); sumResult = 0; expectedResult = 0; for (Map.Entry e : results.entrySet()) { sumResult += e.getValue(); } - for(int i = rowSeperator1;i < ROWSIZE; i++) { + for (int i = rowSeperator1; i < ROWSIZE; i++) { expectedResult += i; } assertEquals("Invalid result", sumResult, expectedResult); } - private static byte [][] makeN(byte [] base, int n) { - byte [][] ret = new byte[n][]; - for(int i=0;i