diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java
index 45eff67..1983ed3 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java
@@ -23,13 +23,17 @@
 import java.util.List;
 import java.util.Set;
 
+import org.apache.commons.lang.StringUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.Context;
 import org.apache.hadoop.hive.ql.ErrorMsg;
 import org.apache.hadoop.hive.ql.exec.Utilities;
 import org.apache.hadoop.hive.ql.exec.mr.ExecMapper;
-import org.apache.hadoop.hive.ql.io.HiveInputFormat;
+import org.apache.hadoop.hive.ql.io.BucketizedHiveInputFormat;
+import org.apache.hadoop.hive.ql.io.CombineHiveInputFormat;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.plan.BaseWork;
 import org.apache.hadoop.hive.ql.plan.MapWork;
@@ -38,6 +42,7 @@
 import org.apache.hadoop.hive.ql.plan.SparkWork;
 import org.apache.hadoop.hive.ql.stats.StatsFactory;
 import org.apache.hadoop.hive.ql.stats.StatsPublisher;
+import org.apache.hadoop.hive.shims.ShimLoader;
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.io.WritableComparable;
@@ -46,6 +51,8 @@
 import org.apache.spark.api.java.JavaSparkContext;
 
 public class SparkPlanGenerator {
+  private static final Log LOG = LogFactory.getLog(SparkPlanGenerator.class);
+
   private JavaSparkContext sc;
   private JobConf jobConf;
   private Context context;
@@ -87,13 +94,37 @@ public SparkPlan generate(SparkWork sparkWork) throws Exception {
     List<Path> inputPaths = Utilities.getInputPaths(jobConf, mapWork, scratchDir, context, false);
     Utilities.setInputPaths(jobConf, inputPaths);
     Utilities.setMapWork(jobConf, mapWork, scratchDir, true);
-    Class ifClass = HiveInputFormat.class;
+    Class ifClass = getInputFormat(mapWork);
 
     // The mapper class is expected by the HiveInputFormat.
     jobConf.set("mapred.mapper.class", ExecMapper.class.getName());
     return sc.hadoopRDD(jobConf, ifClass, WritableComparable.class, Writable.class);
   }
 
+  private Class getInputFormat(MapWork mWork) {
+    if (mWork.getInputformat() != null) {
+      HiveConf.setVar(jobConf, HiveConf.ConfVars.HIVEINPUTFORMAT, mWork.getInputformat());
+    }
+    String inpFormat = HiveConf.getVar(jobConf, HiveConf.ConfVars.HIVEINPUTFORMAT);
+    if (StringUtils.isBlank(inpFormat)) {
+      inpFormat = ShimLoader.getHadoopShims().getInputFormatClassName();
+    }
+
+    if (mWork.isUseBucketizedHiveInputFormat()) {
+      inpFormat = BucketizedHiveInputFormat.class.getName();
+    }
+
+    // Default to CombineHiveInputFormat in case the specified class fails to load.
+    Class inputFormatClass = CombineHiveInputFormat.class;
+    try {
+      inputFormatClass = Class.forName(inpFormat);
+    } catch (ClassNotFoundException e) {
+      LOG.warn("Failed to load the specified input format class: " + inpFormat, e);
+    }
+
+    return inputFormatClass;
+  }
+
   private SparkTran generate(BaseWork bw) throws IOException, HiveException {
     // initialize stats publisher if necessary
     if (bw.isGatheringStats()) {
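
For reviewers who want to trace the precedence that getInputFormat() implements without building Hive, here is a minimal, dependency-free Java sketch of the same fallback chain. Everything in it is illustrative: the class name InputFormatResolutionSketch, the resolve() helper and its parameters, and the com.example.MissingInputFormat string are stand-ins, not part of this patch or of Hive's API.

  // InputFormatResolutionSketch.java -- illustrative only, not part of the patch.
  public class InputFormatResolutionSketch {

    // Mirrors the resolution order in SparkPlanGenerator.getInputFormat():
    //   1. a per-MapWork override, 2. the session-level hive.input.format value,
    //   3. a Hadoop-version default (ShimLoader in the patch),
    //   4. a forced BucketizedHiveInputFormat for bucketized map work,
    //   then Class.forName() with a known-good fallback on load failure.
    static Class<?> resolve(String mapWorkOverride, String confValue,
        String shimDefault, boolean useBucketized, Class<?> fallback) {
      String name = (mapWorkOverride != null) ? mapWorkOverride : confValue;
      if (name == null || name.trim().isEmpty()) {
        name = shimDefault;
      }
      if (useBucketized) {
        // BucketizedHiveInputFormat.class.getName() in the patch.
        name = "org.apache.hadoop.hive.ql.io.BucketizedHiveInputFormat";
      }
      try {
        return Class.forName(name);
      } catch (ClassNotFoundException e) {
        // The patch logs a warning here and keeps CombineHiveInputFormat.
        return fallback;
      }
    }

    public static void main(String[] args) {
      // No overrides and an unloadable default name: the fallback class wins.
      Class<?> chosen = resolve(null, "", "com.example.MissingInputFormat",
          false, java.util.ArrayList.class);
      System.out.println(chosen.getName()); // prints "java.util.ArrayList"
    }
  }

Run with no overrides and an unloadable default name, the sketch prints java.util.ArrayList, i.e. the hard-coded fallback wins; that is the behavior the patch's CombineHiveInputFormat default provides when the configured class cannot be loaded.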