Index: lucene/core/src/java/org/apache/lucene/search/ScoreCachingWrappingScorer.java IDEA additional info: Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP <+>UTF-8 =================================================================== --- lucene/core/src/java/org/apache/lucene/search/ScoreCachingWrappingScorer.java (revision 1413821) +++ lucene/core/src/java/org/apache/lucene/search/ScoreCachingWrappingScorer.java (revision ) @@ -37,7 +37,16 @@ private final Scorer scorer; private int curDoc = -1; private float curScore; - + + /** Wraps the provided scorer in ScoreCachingWrappingScorer if needed. */ + public static ScoreCachingWrappingScorer wrap(Scorer scorer) { + if (scorer instanceof ScoreCachingWrappingScorer) { + return (ScoreCachingWrappingScorer) scorer; + } else { + return new ScoreCachingWrappingScorer(scorer); + } + } + /** Creates a new instance by wrapping the given scorer. */ public ScoreCachingWrappingScorer(Scorer scorer) { super(scorer.weight); Index: lucene/core/src/java/org/apache/lucene/search/FieldComparator.java IDEA additional info: Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP <+>UTF-8 =================================================================== --- lucene/core/src/java/org/apache/lucene/search/FieldComparator.java (revision 1413821) +++ lucene/core/src/java/org/apache/lucene/search/FieldComparator.java (revision ) @@ -17,10 +17,7 @@ * limitations under the License. */ -import java.io.IOException; -import java.util.Comparator; - -import org.apache.lucene.index.AtomicReader; // javadocs +import org.apache.lucene.index.AtomicReader; import org.apache.lucene.index.AtomicReaderContext; import org.apache.lucene.index.DocValues; import org.apache.lucene.search.FieldCache.ByteParser; @@ -35,6 +32,9 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.packed.PackedInts; +import java.io.IOException; +import java.util.Comparator; + /** * Expert: a FieldComparator compares hits so as to determine their * sort order when collecting the top results with {@link @@ -964,14 +964,12 @@ @Override public void setScorer(Scorer scorer) { + // note: if scorer isn't already wrapped in a cache by now then it's + // probably too late. // wrap with a ScoreCachingWrappingScorer so that successive calls to // score() will not incur score computation over and // over again. - if (!(scorer instanceof ScoreCachingWrappingScorer)) { - this.scorer = new ScoreCachingWrappingScorer(scorer); - } else { - this.scorer = scorer; - } + this.scorer = ScoreCachingWrappingScorer.wrap(scorer); } @Override Index: lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionQuerySort.java IDEA additional info: Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP <+>UTF-8 =================================================================== --- lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionQuerySort.java (revision 1413821) +++ lucene/queries/src/test/org/apache/lucene/queries/function/TestFunctionQuerySort.java (revision ) @@ -17,14 +17,16 @@ * limitations under the License. */ -import java.io.IOException; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.StringField; +import org.apache.lucene.index.AtomicReaderContext; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.queries.function.docvalues.IntDocValues; import org.apache.lucene.queries.function.valuesource.IntFieldSource; +import org.apache.lucene.queries.function.valuesource.SingleFunction; import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; @@ -33,12 +35,19 @@ import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopFieldCollector; import org.apache.lucene.store.Directory; import org.apache.lucene.util.LuceneTestCase; +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + /** Test that functionquery's getSortField() actually works */ public class TestFunctionQuerySort extends LuceneTestCase { + public static final int NUM_VALS = 5; + public void testSearchAfterWhenSortingByFunctionValues() throws IOException { Directory dir = newDirectory(); IndexWriterConfig iwc = newIndexWriterConfig(TEST_VERSION_CURRENT, null); @@ -47,10 +56,11 @@ Document doc = new Document(); Field field = new StringField("value", "", Field.Store.YES); + Field constField = new StringField("const", "SAME", Field.Store.YES); doc.add(field); + doc.add(constField); // Save docs unsorted (decreasing value n, n-1, ...) - final int NUM_VALS = 5; for (int val = NUM_VALS; val > 0; val--) { field.setStringValue(Integer.toString(val)); writer.addDocument(doc); @@ -61,8 +71,26 @@ writer.close(); IndexSearcher searcher = new IndexSearcher(reader); + // Lets do a quick search test using FunctionQuery sort score desc + { + MonitorValueSource src = new MonitorValueSource(new IntFieldSource("value")); + Query q = new FunctionQuery(src); + Sort orderBy = new Sort(); + //the const field has no real effect other than to make the sort situation + // not as simple + SortField sf = new SortField(null, SortField.Type.SCORE, true); + SortField sfConst = new SortField("const", SortField.Type.STRING, false); + orderBy.setSort(sfConst, sf); + TopFieldCollector collector = TopFieldCollector.create( + orderBy.rewrite(searcher), reader.maxDoc(), false, true, true, true + ); + searcher.search(q, collector); + verifySortHits(reader, src, collector.topDocs()); + } + //good; continue testing... + // Get ValueSource from FieldCache - IntFieldSource src = new IntFieldSource("value"); + MonitorValueSource src = new MonitorValueSource(new IntFieldSource("value")); // ...and make it a sort criterion SortField sf = src.getSortField(false).rewrite(searcher); Sort orderBy = new Sort(sf); @@ -70,13 +98,7 @@ // Get hits sorted by our FunctionValues (ascending values) Query q = new MatchAllDocsQuery(); TopDocs hits = searcher.search(q, Integer.MAX_VALUE, orderBy); - assertEquals(NUM_VALS, hits.scoreDocs.length); - // Verify that sorting works in general - int i = 0; - for (ScoreDoc hit : hits.scoreDocs) { - int valueFromDoc = Integer.parseInt(reader.document(hit.doc).get("value")); - assertEquals(++i, valueFromDoc); - } + verifySortHits(reader, src, hits); // Now get hits after hit #2 using IS.searchAfter() int afterIdx = 1; @@ -95,5 +117,49 @@ } reader.close(); dir.close(); + } + + private void verifySortHits(IndexReader reader, MonitorValueSource src, TopDocs hits) throws IOException { + assertEquals(NUM_VALS, hits.scoreDocs.length); + assertEquals(NUM_VALS, src.callCount.get()); + // Verify that sorting works in general + int i = 0; + for (ScoreDoc hit : hits.scoreDocs) { + int valueFromDoc = Integer.parseInt(reader.document(hit.doc).get("value")); + assertEquals(++i, valueFromDoc); + } + } + + /** Wraps a ValueSource to increment a counter each time a value is retrieved. */ + static class MonitorValueSource extends SingleFunction { + + final AtomicInteger callCount = new AtomicInteger(); + + public MonitorValueSource(ValueSource source) { + super(source); + } + + @Override + protected String name() { + return "monitor"; + } + + @Override + public FunctionValues getValues(Map context, AtomicReaderContext readerContext) throws IOException { + final FunctionValues vals = source.getValues(context, readerContext); + return new IntDocValues(this) { + + @Override + public int intVal(int doc) { + callCount.incrementAndGet(); + return vals.intVal(doc); + } + + @Override + public String toString(int doc) { + return name() + '(' + vals.toString(doc) + ')'; + } + }; + } } } Index: lucene/core/src/java/org/apache/lucene/search/TopFieldCollector.java IDEA additional info: Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP <+>UTF-8 =================================================================== --- lucene/core/src/java/org/apache/lucene/search/TopFieldCollector.java (revision 1413821) +++ lucene/core/src/java/org/apache/lucene/search/TopFieldCollector.java (revision ) @@ -17,12 +17,12 @@ * limitations under the License. */ -import java.io.IOException; - import org.apache.lucene.index.AtomicReaderContext; import org.apache.lucene.search.FieldValueHitQueue.Entry; import org.apache.lucene.util.PriorityQueue; +import java.io.IOException; + /** * A {@link Collector} that sorts by {@link SortField} using * {@link FieldComparator}s. @@ -33,7 +33,45 @@ * @lucene.experimental */ public abstract class TopFieldCollector extends TopDocsCollector { - + + Scorer scorer; + + @Override + public void setScorer(Scorer scorer) throws IOException { + //We should wrap the scorer in a caching scorer if it isn't already + // wrapped and one of the comparators is a relevancy one. The relevancy + // one uses the scorer. + boolean doCache = false; + FieldComparator[] fieldComparators = getFieldComparators(); + if (!(scorer instanceof ScoreCachingWrappingScorer)) { + for (FieldComparator fieldComparator : fieldComparators) { + if (fieldComparator instanceof FieldComparator.RelevanceComparator) { + doCache = true; + break; + } + } + } + if (doCache) + scorer = new ScoreCachingWrappingScorer(scorer); + + this.scorer = scorer; + // set the scorer on all comparators + for (FieldComparator comparator : fieldComparators) { + comparator.setScorer(scorer); + } + } + + protected Scorer maybeCacheWrapScorer(Scorer scorer) { + if (scorer instanceof ScoreCachingWrappingScorer) + return scorer; + //if any comparator is a RelevanceComparator, then we should wrap the scorer + for (FieldComparator comparator : getFieldComparators()) { + if (comparator instanceof FieldComparator.RelevanceComparator) + return new ScoreCachingWrappingScorer(scorer); + } + return scorer;//don't wrap + } + // TODO: one optimization we could do is to pre-fill // the queue with sentinel value that guaranteed to // always compare lower than a real hit; this would @@ -57,7 +95,12 @@ comparator = queue.getComparators()[0]; reverseMul = queue.getReverseMul()[0]; } - + + @Override + protected FieldComparator[] getFieldComparators() { + return new FieldComparator[]{comparator}; + } + final void updateBottom(int doc) { // bottom.score is already set to Float.NaN in add(). bottom.doc = docBase + doc; @@ -98,13 +141,8 @@ comparator = queue.firstComparator; } - @Override - public void setScorer(Scorer scorer) throws IOException { - comparator.setScorer(scorer); - } - + } + - } - /* * Implements a TopFieldCollector over one SortField criteria, without * tracking document scores and maxScore, and assumes out of orderness in doc @@ -158,8 +196,6 @@ private static class OneComparatorScoringNoMaxScoreCollector extends OneComparatorNonScoringCollector { - Scorer scorer; - public OneComparatorScoringNoMaxScoreCollector(FieldValueHitQueue queue, int numHits, boolean fillFields) { super(queue, numHits, fillFields); @@ -204,14 +240,8 @@ } } - @Override - public void setScorer(Scorer scorer) throws IOException { - this.scorer = scorer; - comparator.setScorer(scorer); - } - + } + - } - /* * Implements a TopFieldCollector over one SortField criteria, while tracking * document scores but no maxScore, and assumes out of orderness in doc Ids @@ -317,13 +347,8 @@ } } - + - @Override - public void setScorer(Scorer scorer) throws IOException { - this.scorer = scorer; - super.setScorer(scorer); - } + } - } /* * Implements a TopFieldCollector over one SortField criteria, with tracking @@ -391,7 +416,12 @@ comparators = queue.getComparators(); reverseMul = queue.getReverseMul(); } - + + @Override + protected FieldComparator[] getFieldComparators() { + return comparators; + } + final void updateBottom(int doc) { // bottom.score is already set to Float.NaN in add(). bottom.doc = docBase + doc; @@ -453,14 +483,7 @@ } } - @Override - public void setScorer(Scorer scorer) throws IOException { - // set the scorer on all comparators - for (int i = 0; i < comparators.length; i++) { - comparators[i].setScorer(scorer); - } + } - } - } /* * Implements a TopFieldCollector over multiple SortField criteria, without @@ -536,9 +559,7 @@ * tracking document scores and maxScore. */ private static class MultiComparatorScoringMaxScoreCollector extends MultiComparatorNonScoringCollector { - + - Scorer scorer; - public MultiComparatorScoringMaxScoreCollector(FieldValueHitQueue queue, int numHits, boolean fillFields) { super(queue, numHits, fillFields); @@ -603,12 +624,7 @@ } } - @Override - public void setScorer(Scorer scorer) throws IOException { - this.scorer = scorer; - super.setScorer(scorer); - } + } - } /* * Implements a TopFieldCollector over multiple SortField criteria, with @@ -754,12 +770,7 @@ } } - @Override - public void setScorer(Scorer scorer) throws IOException { - this.scorer = scorer; - super.setScorer(scorer); - } + } - } /* * Implements a TopFieldCollector over multiple SortField criteria, with @@ -829,12 +840,6 @@ } @Override - public void setScorer(Scorer scorer) throws IOException { - this.scorer = scorer; - super.setScorer(scorer); - } - - @Override public boolean acceptsDocsOutOfOrder() { return true; } @@ -846,7 +851,6 @@ */ private final static class PagingFieldCollector extends TopFieldCollector { - Scorer scorer; int collectedHits; final FieldComparator[] comparators; final int[] reverseMul; @@ -870,7 +874,12 @@ // Must set maxScore to NEG_INF, or otherwise Math.max always returns NaN. maxScore = Float.NEGATIVE_INFINITY; } - + + @Override + protected FieldComparator[] getFieldComparators() { + return comparators; + } + void updateBottom(int doc, float score) { bottom.doc = docBase + doc; bottom.score = score; @@ -976,16 +985,8 @@ } } } - + @Override - public void setScorer(Scorer scorer) { - this.scorer = scorer; - for (int i = 0; i < comparators.length; i++) { - comparators[i].setScorer(scorer); - } - } - - @Override public boolean acceptsDocsOutOfOrder() { return true; } @@ -1025,6 +1026,12 @@ this.numHits = numHits; this.fillFields = fillFields; } + + /** + * Provides a collection of the comparators. To be treated as read-only. Not + * null. + */ + protected abstract FieldComparator[] getFieldComparators(); /** * Creates a new {@link TopFieldCollector} from the given