diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/factories/RecordFactory.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/factories/RecordFactory.java index bd95c6b..dbd838e 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/factories/RecordFactory.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/java/org/apache/hadoop/yarn/factories/RecordFactory.java @@ -19,10 +19,22 @@ package org.apache.hadoop.yarn.factories; import org.apache.hadoop.classification.InterfaceAudience.LimitedPrivate; +import org.apache.hadoop.classification.InterfaceAudience.Private; import org.apache.hadoop.classification.InterfaceStability.Unstable; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + @LimitedPrivate({ "MapReduce", "YARN" }) @Unstable public interface RecordFactory { public T newRecordInstance(Class clazz); + + @Private + public void write(T record, OutputStream output) throws IOException; + + @Private + public T read(InputStream input, Class clazz) throws IOException; + } diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/java/org/apache/hadoop/yarn/factories/impl/pb/RecordFactoryPBImpl.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/java/org/apache/hadoop/yarn/factories/impl/pb/RecordFactoryPBImpl.java index 5e75b8d..ccdd72a 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/java/org/apache/hadoop/yarn/factories/impl/pb/RecordFactoryPBImpl.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/java/org/apache/hadoop/yarn/factories/impl/pb/RecordFactoryPBImpl.java @@ -18,11 +18,16 @@ package org.apache.hadoop.yarn.factories.impl.pb; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.lang.reflect.Constructor; +import java.lang.reflect.Method; import java.lang.reflect.InvocationTargetException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import com.google.protobuf.Message; import org.apache.hadoop.classification.InterfaceAudience.Private; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.yarn.exceptions.YarnRuntimeException; @@ -46,18 +51,23 @@ public static RecordFactory get() { } @SuppressWarnings("unchecked") + private Class getPBImplClass(Class clazz) { + Class pbClazz = null; + try { + return (Class) localConf.getClassByName(getPBImplClassName(clazz)); + } catch (ClassNotFoundException e) { + throw new YarnRuntimeException("Failed to load class: [" + + getPBImplClassName(clazz) + "]", e); + } + } + + @SuppressWarnings("unchecked") @Override public T newRecordInstance(Class clazz) { Constructor constructor = cache.get(clazz); if (constructor == null) { - Class pbClazz = null; - try { - pbClazz = localConf.getClassByName(getPBImplClassName(clazz)); - } catch (ClassNotFoundException e) { - throw new YarnRuntimeException("Failed to load class: [" - + getPBImplClassName(clazz) + "]", e); - } + Class pbClazz = getPBImplClass(clazz); try { constructor = pbClazz.getConstructor(); constructor.setAccessible(true); @@ -78,6 +88,44 @@ public static RecordFactory get() { } } + @SuppressWarnings("unchecked") + private Message getProto(Class klass) { + try { + Method method = klass.getMethod("getProto"); + Class messageClass = method.getReturnType(); + method = messageClass.getMethod("newBuilder"); + Object builder = method.invoke(null); + method = builder.getClass().getMethod("build"); + return (Message) method.invoke(builder); + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + + @Override + public void write(T record, OutputStream output) throws IOException { + Message message = getProto(record.getClass()); + message.writeTo(output); + } + + @Override + @SuppressWarnings("unchecked") + public T read(InputStream input, Class clazz) throws IOException { + try { + Class pbImplClass =getPBImplClass(clazz); + Message message = getProto(pbImplClass); + Message.Builder builder = message.newBuilderForType(); + Message proto = builder.mergeFrom(input).build(); + Class protoClass = proto.getClass(); + Constructor recordConstructor = pbImplClass.getConstructor(protoClass); + return (T) recordConstructor.newInstance(proto); + } catch (IOException ex) { + throw ex; + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + private String getPBImplClassName(Class clazz) { String srcPackagePart = getPackageName(clazz); String srcClassName = getClassName(clazz); diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/test/java/org/apache/hadoop/yarn/TestRecordFactory.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/test/java/org/apache/hadoop/yarn/TestRecordFactory.java index e9dcc38..0b9791b 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/test/java/org/apache/hadoop/yarn/TestRecordFactory.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/test/java/org/apache/hadoop/yarn/TestRecordFactory.java @@ -20,7 +20,9 @@ import junit.framework.Assert; +import org.apache.hadoop.yarn.api.records.impl.pb.ApplicationIdPBImpl; import org.apache.hadoop.yarn.exceptions.YarnRuntimeException; +import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.hadoop.yarn.factories.RecordFactory; import org.apache.hadoop.yarn.factories.impl.pb.RecordFactoryPBImpl; import org.apache.hadoop.yarn.api.protocolrecords.AllocateRequest; @@ -29,6 +31,11 @@ import org.apache.hadoop.yarn.api.protocolrecords.impl.pb.AllocateResponsePBImpl; import org.junit.Test; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.InputStream; +import java.lang.reflect.Method; + public class TestRecordFactory { @Test @@ -54,4 +61,26 @@ public void testPbRecordFactory() { } } + @Test + public void testSerDeser() throws Exception { + RecordFactory factory = RecordFactoryPBImpl.get(); + ApplicationId id1 = factory.newRecordInstance(ApplicationId.class); + + //trick to be able to set the id for testing + Method method = id1.getClass().getDeclaredMethod("setId", Integer.TYPE); + method.setAccessible(true); + method.invoke(id1, 12345); + + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + factory.write(id1, baos); + baos.close(); + InputStream is = new ByteArrayInputStream(baos.toByteArray()); + ApplicationId id2 = factory.read(is, ApplicationId.class); + + //forcing id1 to build the proto, otherwise equals() will fail + ((ApplicationIdPBImpl)id1).getProto(); + + Assert.assertEquals(id1, id2); + } + }