diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/mr/MapRedTask.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/mr/MapRedTask.java
index 8266906c338c88fa35bc3a3b68e45c1ab6c2243c..0f594a16252503bfca9fbddfc7c6551a5b91c89f 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/mr/MapRedTask.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/mr/MapRedTask.java
@@ -49,10 +49,16 @@
 import org.apache.hadoop.hive.ql.session.SessionState;
 import org.apache.hadoop.hive.ql.session.SessionState.ResourceType;
 import org.apache.hadoop.hive.shims.ShimLoader;
+import org.apache.hadoop.hive.shims.Utils;
 import org.apache.hive.common.util.StreamPrinter;
 import org.apache.hadoop.mapred.RunningJob;
+
+import com.google.common.annotations.VisibleForTesting;
+
 import org.json.JSONException;
 
+import static org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod.PROXY;
+
 /**
  * Extension of ExecDriver:
  * - can optionally spawn a map-reduce task from a separate jvm
@@ -71,6 +77,7 @@
   static final String HIVE_MAIN_CLIENT_DEBUG_OPTS = "HIVE_MAIN_CLIENT_DEBUG_OPTS";
   static final String HIVE_CHILD_CLIENT_DEBUG_OPTS = "HIVE_CHILD_CLIENT_DEBUG_OPTS";
   static final String[] HIVE_SYS_PROP = {"build.dir", "build.dir.hive", "hive.query.id"};
+  static final String HADOOP_PROXY_USER = "HADOOP_PROXY_USER";
 
   private transient ContentSummary inputSummary = null;
   private transient boolean runningViaChild = false;
@@ -267,6 +274,10 @@ public int execute(DriverContext driverContext) {
         configureDebugVariablesForChildJVM(variables);
       }
 
+      if (PROXY == Utils.getUGI().getAuthenticationMethod()) {
+        variables.put(HADOOP_PROXY_USER, Utils.getUGI().getShortUserName());
+      }
+
       env = new String[variables.size()];
       int pos = 0;
       for (Map.Entry<String, String> entry : variables.entrySet()) {
@@ -275,7 +286,7 @@
         env[pos++] = name + "=" + value;
       }
 
       // Run ExecDriver in another JVM
-      executor = Runtime.getRuntime().exec(cmdLine, env, new File(workDir));
+      executor = spawn(cmdLine, workDir, env);
 
       CachingPrintStream errPrintStream = new CachingPrintStream(SessionState.getConsole().getChildErrStream());
@@ -323,6 +334,11 @@ public int execute(DriverContext driverContext) {
     }
   }
 
+  @VisibleForTesting
+  Process spawn(String cmdLine, String workDir, String[] env) throws IOException {
+    return Runtime.getRuntime().exec(cmdLine, env, new File(workDir));
+  }
+
   static void configureDebugVariablesForChildJVM(Map<String, String> environmentVariables) {
     // this method contains various asserts to warn if environment variables are in a buggy state
     assert environmentVariables.containsKey(HADOOP_CLIENT_OPTS)
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/mr/TestMapRedTask.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/mr/TestMapRedTask.java
index 0693d2458b8fd6c4df62dee4c02d4f6a04199c70..f1041ce933893710af37153bfba3b0e43bcb4e33 100644
--- a/ql/src/test/org/apache/hadoop/hive/ql/exec/mr/TestMapRedTask.java
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/mr/TestMapRedTask.java
@@ -17,16 +17,32 @@
  */
 package org.apache.hadoop.hive.ql.exec.mr;
 
+import static org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod.PROXY;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 import java.io.IOException;
+import java.util.Arrays;
+import javax.security.auth.login.LoginException;
+
+import org.apache.hadoop.fs.Path;
 
 import org.apache.hadoop.hive.common.metrics.common.Metrics;
 import org.apache.hadoop.hive.common.metrics.common.MetricsConstant;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.Context;
+import org.apache.hadoop.hive.ql.DriverContext;
+import org.apache.hadoop.hive.ql.QueryState;
 import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
+import org.apache.hadoop.hive.ql.plan.MapWork;
+import org.apache.hadoop.hive.ql.plan.MapredWork;
+import org.apache.hadoop.hive.shims.Utils;
+
+import org.junit.Assert;
 import org.junit.Test;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mockito;
 
 public class TestMapRedTask {
@@ -44,4 +60,37 @@ public void mrTask_updates_Metrics() throws IOException {
     verify(mockMetrics, never()).incrementCounter(MetricsConstant.HIVE_SPARK_TASKS);
   }
 
+  @Test
+  public void mrTask_auth_method_Proxy() throws IOException, LoginException {
+    Utils.getUGI().setAuthenticationMethod(PROXY);
+
+    Context ctx = Mockito.mock(Context.class);
+    when(ctx.getLocalTmpPath()).thenReturn(new Path(System.getProperty("java.io.tmpdir")));
+
+    DriverContext dctx = new DriverContext(ctx);
+
+    QueryState queryState = new QueryState.Builder().build();
+    HiveConf conf = queryState.getConf();
+    conf.setBoolVar(HiveConf.ConfVars.SUBMITVIACHILD, true);
+
+    MapredWork mrWork = new MapredWork();
+    mrWork.setMapWork(Mockito.mock(MapWork.class));
+
+    MapRedTask mrTask = Mockito.spy(new MapRedTask());
+    mrTask.setWork(mrWork);
+
+    mrTask.initialize(queryState, null, dctx, null);
+
+    mrTask.jobExecHelper = Mockito.mock(HadoopJobExecHelper.class);
+    when(mrTask.jobExecHelper.progressLocal(Mockito.any(Process.class), Mockito.anyString())).thenReturn(0);
+
+    mrTask.execute(dctx);
+
+    ArgumentCaptor<String[]> captor = ArgumentCaptor.forClass(String[].class);
+    verify(mrTask).spawn(Mockito.anyString(), Mockito.anyString(), captor.capture());
+
+    String expected = "HADOOP_PROXY_USER=" + Utils.getUGI().getUserName();
+    Assert.assertTrue(Arrays.asList(captor.getValue()).contains(expected));
+  }
+
 }
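Reviewer note (not part of the patch): forwarding HADOOP_PROXY_USER to the child process takes effect because Hadoop's UserGroupInformation consults that environment variable when the child JVM establishes its login user and wraps the real user in a proxy UGI with authentication method PROXY. Below is a minimal, self-contained sketch of that handshake, using only public UserGroupInformation APIs; the class name ProxyUserSketch is illustrative and assumed, not something introduced by this change.

// Sketch only: approximates what UserGroupInformation.getLoginUser() does with the
// HADOOP_PROXY_USER variable that MapRedTask now places into the child environment.
import java.io.IOException;

import org.apache.hadoop.security.UserGroupInformation;

public class ProxyUserSketch {
  public static void main(String[] args) throws IOException {
    // The "real" (login) user of this JVM.
    UserGroupInformation realUser = UserGroupInformation.getCurrentUser();

    // If the parent exported HADOOP_PROXY_USER, recreate the proxied identity.
    String proxyUser = System.getenv("HADOOP_PROXY_USER");
    UserGroupInformation ugi = (proxyUser == null)
        ? realUser
        : UserGroupInformation.createProxyUser(proxyUser, realUser);

    // For a proxied session this prints the short user name with authentication
    // method PROXY, matching the state the new check in MapRedTask.execute() detects
    // in the parent before spawning the child ExecDriver.
    System.out.println(ugi.getShortUserName() + " / " + ugi.getAuthenticationMethod());
  }
}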