diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java index 45d3cd1..8974e9b 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java @@ -493,7 +493,7 @@ private void processAllEvents(String inputName, UserPayload getBytePayload(Multimap routingTable) throws IOException { CustomEdgeConfiguration edgeConf = - new CustomEdgeConfiguration(routingTable.keySet().size(), routingTable); + new CustomEdgeConfiguration(numBuckets, routingTable); DataOutputBuffer dob = new DataOutputBuffer(); edgeConf.write(dob); byte[] serialized = dob.getData(); diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/tez/TestCustomPartitionVertex.java ql/src/test/org/apache/hadoop/hive/ql/exec/tez/TestCustomPartitionVertex.java new file mode 100644 index 0000000..dbdd955 --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/exec/tez/TestCustomPartitionVertex.java @@ -0,0 +1,43 @@ +package org.apache.hadoop.hive.ql.exec.tez; + +import com.google.common.collect.HashMultimap; +import com.google.common.collect.Multimap; +import org.apache.hadoop.hive.ql.plan.TezWork; +import org.apache.hadoop.io.DataOutputBuffer; +import org.apache.tez.dag.api.UserPayload; +import org.apache.tez.dag.api.VertexManagerPluginContext; +import org.junit.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class TestCustomPartitionVertex { + @Test(timeout = 5000) + public void testGetBytePayload() throws IOException { + int numBuckets = 10; + VertexManagerPluginContext context = mock(VertexManagerPluginContext.class); + CustomVertexConfiguration vertexConf = + new CustomVertexConfiguration(numBuckets, TezWork.VertexType.INITIALIZED_EDGES); + DataOutputBuffer dob = new DataOutputBuffer(); + vertexConf.write(dob); + UserPayload payload = UserPayload.create(ByteBuffer.wrap(dob.getData())); + when(context.getUserPayload()).thenReturn(payload); + + CustomPartitionVertex vm = new CustomPartitionVertex(context); + vm.initialize(); + + // prepare empty routing table + Multimap routingTable = HashMultimap. create(); + payload = vm.getBytePayload(routingTable); + // get conf from user payload + CustomEdgeConfiguration edgeConf = new CustomEdgeConfiguration(); + DataInputByteBuffer dibb = new DataInputByteBuffer(); + dibb.reset(payload.getPayload()); + edgeConf.readFields(dibb); + assertEquals(numBuckets, edgeConf.getNumBuckets()); + } +}