diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/StatsTask.java ql/src/java/org/apache/hadoop/hive/ql/exec/StatsTask.java index 142af10..37343df 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/StatsTask.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/StatsTask.java @@ -34,6 +34,7 @@ import org.apache.hadoop.hive.metastore.Warehouse; import org.apache.hadoop.hive.ql.DriverContext; import org.apache.hadoop.hive.ql.ErrorMsg; +import org.apache.hadoop.hive.ql.exec.mr.MapRedTask; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.metadata.Partition; import org.apache.hadoop.hive.ql.metadata.Table; @@ -191,7 +192,11 @@ private int aggregateStats() { if (factory != null) { statsAggregator = factory.getStatsAggregator(); // manufacture a StatsAggregator - if (!statsAggregator.connect(conf, getWork().getSourceTask())) { + MapRedTask sourceTask = getWork().getSourceTask(); + if (sourceTask == null) { + throw new HiveException("SourceTask for StatsAggregator should not be null"); + } + if (!statsAggregator.connect(conf, sourceTask)) { throw new HiveException("StatsAggregator connect failed " + statsImplementationClass); } } diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/GenMRTableScan1.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/GenMRTableScan1.java index 0268f98..8d3dc56 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/GenMRTableScan1.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/GenMRTableScan1.java @@ -35,6 +35,7 @@ import org.apache.hadoop.hive.ql.exec.TableScanOperator; import org.apache.hadoop.hive.ql.exec.Task; import org.apache.hadoop.hive.ql.exec.TaskFactory; +import org.apache.hadoop.hive.ql.exec.mr.MapRedTask; import org.apache.hadoop.hive.ql.io.rcfile.stats.PartialScanWork; import org.apache.hadoop.hive.ql.lib.Node; import org.apache.hadoop.hive.ql.lib.NodeProcessor; @@ -74,7 +75,7 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx opProcCtx, // create a dummy MapReduce task MapredWork currWork = GenMapRedUtils.getMapRedWork(parseCtx); - Task currTask = TaskFactory.get(currWork, parseCtx.getConf()); + MapRedTask currTask = (MapRedTask) TaskFactory.get(currWork, parseCtx.getConf()); Operator currTopOp = op; ctx.setCurrTask(currTask); ctx.setCurrTopOp(currTopOp); @@ -95,6 +96,7 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx opProcCtx, StatsWork statsWork = new StatsWork(parseCtx.getQB().getParseInfo().getTableSpec()); statsWork.setAggKey(op.getConf().getStatsAggPrefix()); + statsWork.setSourceTask(currTask); statsWork.setStatsReliable( parseCtx.getConf().getBoolVar(HiveConf.ConfVars.HIVE_STATS_RELIABLE)); Task statsTask = TaskFactory.get(statsWork, parseCtx.getConf());