diff --git hbase-server/src/main/java/org/apache/hadoop/hbase/mapreduce/TableInputFormatBase.java hbase-server/src/main/java/org/apache/hadoop/hbase/mapreduce/TableInputFormatBase.java index 325a781..4123467 100644 --- hbase-server/src/main/java/org/apache/hadoop/hbase/mapreduce/TableInputFormatBase.java +++ hbase-server/src/main/java/org/apache/hadoop/hbase/mapreduce/TableInputFormatBase.java @@ -18,6 +18,7 @@ */ package org.apache.hadoop.hbase.mapreduce; +import java.io.Closeable; import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; @@ -109,6 +110,8 @@ extends InputFormat { private HashMap reverseDNSCacheMap = new HashMap(); + private Connection connection; + /** * Builds a {@link TableRecordReader}. If no {@link TableRecordReader} was provided, uses * the default. @@ -132,19 +135,55 @@ extends InputFormat { } TableSplit tSplit = (TableSplit) split; LOG.info("Input split length: " + StringUtils.humanReadableInt(tSplit.getLength()) + " bytes."); - TableRecordReader trr = this.tableRecordReader; - // if no table record reader was provided use default - if (trr == null) { - trr = new TableRecordReader(); - } + final TableRecordReader trr = + this.tableRecordReader != null ? this.tableRecordReader : new TableRecordReader(); Scan sc = new Scan(this.scan); sc.setStartRow(tSplit.getStartRow()); sc.setStopRow(tSplit.getEndRow()); trr.setScan(sc); trr.setTable(table); - return trr; + return new RecordReader() { + + @Override + public void close() throws IOException { + trr.close(); + close(admin, table, regionLocator, connection); + } + + private void close(Closeable... closables) throws IOException { + for (Closeable c : closables) { + if(c != null) { c.close(); } + } + } + + @Override + public ImmutableBytesWritable getCurrentKey() throws IOException, InterruptedException { + return trr.getCurrentKey(); + } + + @Override + public Result getCurrentValue() throws IOException, InterruptedException { + return trr.getCurrentValue(); + } + + @Override + public float getProgress() throws IOException, InterruptedException { + return trr.getProgress(); + } + + @Override + public void initialize(InputSplit inputsplit, TaskAttemptContext context) throws IOException, + InterruptedException { + trr.initialize(inputsplit, context); + } + + @Override + public boolean nextKeyValue() throws IOException, InterruptedException { + return trr.nextKeyValue(); + } + }; } - + protected Pair getStartEndKeys() throws IOException { return regionLocator.getStartEndKeys(); } @@ -331,6 +370,7 @@ extends InputFormat { this.table = connection.getTable(tableName); this.regionLocator = connection.getRegionLocator(tableName); this.admin = connection.getAdmin(); + this.connection = connection; } /**