Index: modules/grouping/src/java/org/apache/lucene/search/grouping/package.html
===================================================================
--- modules/grouping/src/java/org/apache/lucene/search/grouping/package.html (revision 1145133)
+++ modules/grouping/src/java/org/apache/lucene/search/grouping/package.html (revision )
@@ -164,5 +164,20 @@
   have to separately retrieve it (for example using stored fields, FieldCache, etc.).
+
+<p>
+  Another collector is the TermAllGroupHeadsCollector that can be used to retrieve the most relevant document
+  of each group, also known as the group head. This can be useful when one wants to compute grouping based
+  facets / statistics on the complete query result. The collector can be executed during the first or second phase.
+</p>
+
+<pre class="prettyprint">
+  AbstractAllGroupHeadsCollector c = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup);
+  s.search(new TermQuery(new Term("content", searchTerm)), c);
+  // Return all group heads as an int array
+  int[] groupHeadsArray = c.retrieveGroupHeads();
+  // Return all group heads as an OpenBitSet
+  int maxDoc = s.maxDoc();
+  OpenBitSet groupHeadsBitSet = c.retrieveGroupHeads(maxDoc);
+</pre>
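+
+<p>
+  As a sketch of the facets / statistics use case (illustrative only; "s" and "c" are the searcher and
+  collector from the example above), the returned OpenBitSet can be walked with a DocIdSetIterator to
+  visit every group head and accumulate per-group statistics:
+</p>
+
+<pre class="prettyprint">
+  OpenBitSet groupHeads = c.retrieveGroupHeads(s.maxDoc());
+  DocIdSetIterator it = groupHeads.iterator();
+  int doc;
+  int count = 0;
+  while ((doc = it.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
+    count++; // e.g. read a stored field or FieldCache value for doc here
+  }
+</pre>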
+
Index: lucene/CHANGES.txt
===================================================================
--- lucene/CHANGES.txt (revision 1147580)
+++ lucene/CHANGES.txt (revision )
@@ -540,6 +540,10 @@
   AbstractField.setOmitTermFrequenciesAndPositions is deprecated, you
   should use DOCS_ONLY instead.  (Robert Muir)
 
+* LUCENE-3097: Added a new grouping collector that can be used to retrieve the most relevant
+  document per group. This can be useful in situations when one wants to compute grouping
+  based facets / statistics on the complete query result. (Martijn van Groningen)
+
 Optimizations
 
 * LUCENE-3201, LUCENE-3218: CompoundFileSystem code has been consolidated
Index: modules/grouping/src/test/org/apache/lucene/search/grouping/TermAllGroupHeadsCollectorTest.java
===================================================================
--- modules/grouping/src/test/org/apache/lucene/search/grouping/TermAllGroupHeadsCollectorTest.java (revision )
+++ modules/grouping/src/test/org/apache/lucene/search/grouping/TermAllGroupHeadsCollectorTest.java (revision )
@@ -0,0 +1,173 @@
+package org.apache.lucene.search.grouping;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.lucene.analysis.MockAnalyzer;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.index.RandomIndexWriter;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.*;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.LuceneTestCase;
+import org.apache.lucene.util.OpenBitSet;
+
+import java.io.IOException;
+
+public class TermAllGroupHeadsCollectorTest extends LuceneTestCase {
+
+  public void testRetrieveGroupHeadsAsArrayAndOpenBitset() throws Exception {
+    final String groupField = "author";
+    Directory dir = newDirectory();
+    RandomIndexWriter w = new RandomIndexWriter(
+        random,
+        dir,
+        newIndexWriterConfig(TEST_VERSION_CURRENT,
+            new MockAnalyzer(random)).setMergePolicy(newLogMergePolicy()));
+    index(w, groupField);
+    IndexSearcher indexSearcher = new IndexSearcher(w.getReader());
+    w.close();
+
+    Sort sortWithinGroup = new Sort(new SortField("id", SortField.Type.INT, true));
+    AbstractAllGroupHeadsCollector c1 = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup);
+    indexSearcher.search(new TermQuery(new Term("content", "random")), c1);
+    assertTrue(arrayContains(new int[]{2, 3, 5, 6}, c1.retrieveGroupHeads()));
+    assertTrue(openBitSetContains(new int[]{2, 3, 5, 6}, c1.retrieveGroupHeads(6)));
+
+    AbstractAllGroupHeadsCollector c2 = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup);
+    indexSearcher.search(new TermQuery(new Term("content", "some")), c2);
+    assertTrue(arrayContains(new int[]{2, 3, 4}, c2.retrieveGroupHeads()));
+    assertTrue(openBitSetContains(new int[]{2, 3, 4}, c2.retrieveGroupHeads(6)));
+
+    AbstractAllGroupHeadsCollector c3 = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup);
+    indexSearcher.search(new TermQuery(new Term("content", "blob")), c3);
+    assertTrue(arrayContains(new int[]{1, 5}, c3.retrieveGroupHeads()));
+    assertTrue(openBitSetContains(new int[]{1, 5}, c3.retrieveGroupHeads(6)));
+
+    // STRING sort type triggers different implementation
+    Sort sortWithinGroup2 = new Sort(new SortField("id", SortField.Type.STRING, true));
+    AbstractAllGroupHeadsCollector c4 = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup2);
+    indexSearcher.search(new TermQuery(new Term("content", "random")), c4);
+    assertTrue(arrayContains(new int[]{2, 3, 5, 6}, c4.retrieveGroupHeads()));
+    assertTrue(openBitSetContains(new int[]{2, 3, 5, 6}, c4.retrieveGroupHeads(6)));
+
+    Sort sortWithinGroup3 = new Sort(new SortField("id", SortField.Type.STRING, false));
+    AbstractAllGroupHeadsCollector c5 = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup3);
+    indexSearcher.search(new TermQuery(new Term("content", "random")), c5);
+    // 6 b/c higher doc id wins, even if order of field is not in reverse.
+ assertTrue(arrayContains(new int[]{0, 3, 4, 6}, c5.retrieveGroupHeads())); + assertTrue(openBitSetContains(new int[]{0, 3, 4, 6}, c5.retrieveGroupHeads(6))); + + indexSearcher.getIndexReader().close(); + dir.close(); + } + + private boolean arrayContains(int[] expected, int[] actual) { + if (expected.length != actual.length) { + return false; + } + + for (int e : expected) { + boolean found = false; + for (int a : actual) { + if (e == a) { + found = true; + } + } + + if (!found) { + return false; + } + } + + return true; + } + + private boolean openBitSetContains(int[] expectedDocs, OpenBitSet actual) throws IOException { + if (expectedDocs.length != actual.cardinality()) { + return false; + } + + OpenBitSet expected = new OpenBitSet(expectedDocs.length); + for (int expectedDoc : expectedDocs) { + expected.fastSet(expectedDoc); + } + + int docId; + DocIdSetIterator iterator = expected.iterator(); + while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + if (!actual.fastGet(docId)) { + return false; + } + } + + return true; + } + + private void index(RandomIndexWriter w, String groupField) throws IOException { + // 0 + Document doc = new Document(); + doc.add(new Field(groupField, "author1", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("content", "random text", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "1", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + + // 1 + doc = new Document(); + doc.add(new Field(groupField, "author1", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("content", "some more random text blob", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "2", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + + // 2 + doc = new Document(); + doc.add(new Field(groupField, "author1", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("content", "some more random textual data", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "3", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + w.commit(); // To ensure a second segment + + // 3 + doc = new Document(); + doc.add(new Field(groupField, "author2", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("content", "some random text", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "4", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + + // 4 + doc = new Document(); + doc.add(new Field(groupField, "author3", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("content", "some more random text", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "5", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + + // 5 + doc = new Document(); + doc.add(new Field(groupField, "author3", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("content", "random blob", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "6", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + + // 6 -- no author field + doc = new Document(); + doc.add(new Field("content", "random word stuck in alot of other text", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "6", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + } + +} Index: modules/grouping/src/java/org/apache/lucene/search/grouping/TermAllGroupHeadsCollector.java 
===================================================================
--- modules/grouping/src/java/org/apache/lucene/search/grouping/TermAllGroupHeadsCollector.java (revision )
+++ modules/grouping/src/java/org/apache/lucene/search/grouping/TermAllGroupHeadsCollector.java (revision )
@@ -0,0 +1,533 @@
+package org.apache.lucene.search.grouping;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.search.*;
+import org.apache.lucene.util.BytesRef;
+
+import java.io.IOException;
+import java.util.*;
+
+/**
+ * A base implementation of {@link AbstractAllGroupHeadsCollector} for retrieving the most relevant groups when grouping
+ * on a string based group field. More specifically, all concrete implementations of this base implementation
+ * use {@link org.apache.lucene.search.FieldCache.DocTermsIndex}.
+ *
+ * @lucene.experimental
+ */
+public abstract class TermAllGroupHeadsCollector extends AbstractAllGroupHeadsCollector {
+
+  private static final int DEFAULT_INITIAL_SIZE = 128;
+
+  final String groupField;
+  final BytesRef scratchBytesRef = new BytesRef();
+
+  FieldCache.DocTermsIndex groupIndex;
+  IndexReader.AtomicReaderContext readerContext;
+
+  protected TermAllGroupHeadsCollector(String groupField, int numberOfSorts) {
+    super(numberOfSorts);
+    this.groupField = groupField;
+  }
+
+  /**
+   * Creates an AbstractAllGroupHeadsCollector instance based on the supplied arguments.
+   * This factory method decides which implementation is best suited.
+ * + * @param groupField The field to group by + * @param sortWithinGroup The sort within each group + * @return an AbstractAllGroupHeadsCollector instance based on the supplied arguments + * @throws IOException If I/O related errors occur + */ + public static AbstractAllGroupHeadsCollector create(String groupField, Sort sortWithinGroup) throws IOException { + return create(groupField, sortWithinGroup, DEFAULT_INITIAL_SIZE); + } + + public static AbstractAllGroupHeadsCollector create(String groupField, Sort sortWithinGroup, int initialSize) throws IOException { + boolean sortAllScore = true; + boolean sortAllFieldValue = true; + + for (SortField sortField : sortWithinGroup.getSort()) { + if (sortField.getType() == SortField.Type.SCORE) { + sortAllFieldValue = false; + } else if (needGeneralImpl(sortField)) { + return new GeneralAllGroupHeadsCollector(groupField, sortWithinGroup); + } else { + sortAllScore = false; + } + } + + if (sortAllScore) { + return new ScoreAllGroupHeadsCollector(groupField, sortWithinGroup, initialSize); + } else if (sortAllFieldValue) { + return new OrdAllGroupHeadsCollector(groupField, sortWithinGroup, initialSize); + } else { + return new OrdScoreAllGroupHeadsCollector(groupField, sortWithinGroup, initialSize); + } + } + + // Returns when a sort field needs the general impl. + private static boolean needGeneralImpl(SortField sortField) { + SortField.Type sortType = sortField.getType(); + // Note (MvG): We can also make an optimized impl when sorting is SortField.DOC + return sortType != SortField.Type.STRING_VAL && sortType != SortField.Type.STRING && sortType != SortField.Type.SCORE; + } + + // A general impl that works for any group sort. + static class GeneralAllGroupHeadsCollector extends TermAllGroupHeadsCollector { + + private final Sort sortWithinGroup; + private final Map groups; + + private Scorer scorer; + + GeneralAllGroupHeadsCollector(String groupField, Sort sortWithinGroup) throws IOException { + super(groupField, sortWithinGroup.getSort().length); + this.sortWithinGroup = sortWithinGroup; + groups = new HashMap(); + + final SortField[] sortFields = sortWithinGroup.getSort(); + for (int i = 0; i < sortFields.length; i++) { + reversed[i] = sortFields[i].getReverse() ? -1 : 1; + } + } + + protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException { + final int ord = groupIndex.getOrd(doc); + final BytesRef groupValue = ord == 0 ? null : groupIndex.lookup(ord, scratchBytesRef); + GroupHead groupHead = groups.get(groupValue); + if (groupHead == null) { + groupHead = new GroupHead(groupValue, sortWithinGroup, doc); + groups.put(groupValue == null ? 
null : new BytesRef(groupValue), groupHead); + temporalResult.stop = true; + } else { + temporalResult.stop = false; + } + temporalResult.groupHead = groupHead; + } + + protected Collection getCollectedGroupHeads() { + return groups.values(); + } + + public void setNextReader(IndexReader.AtomicReaderContext context) throws IOException { + this.readerContext = context; + groupIndex = FieldCache.DEFAULT.getTermsIndex(context.reader, groupField); + + for (GroupHead groupHead : groups.values()) { + for (int i = 0; i < groupHead.comparators.length; i++) { + groupHead.comparators[i] = groupHead.comparators[i].setNextReader(context); + } + } + } + + public void setScorer(Scorer scorer) throws IOException { + this.scorer = scorer; + for (GroupHead groupHead : groups.values()) { + for (FieldComparator comparator : groupHead.comparators) { + comparator.setScorer(scorer); + } + } + } + + class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead { + + final FieldComparator[] comparators; + + private GroupHead(BytesRef groupValue, Sort sort, int doc) throws IOException { + super(groupValue, doc + readerContext.docBase); + final SortField[] sortFields = sort.getSort(); + comparators = new FieldComparator[sortFields.length]; + for (int i = 0; i < sortFields.length; i++) { + comparators[i] = sortFields[i].getComparator(1, i).setNextReader(readerContext); + comparators[i].setScorer(scorer); + comparators[i].copy(0, doc); + comparators[i].setBottom(0); + } + } + + public int compare(int compIDX, int doc) throws IOException { + return comparators[compIDX].compareBottom(doc); + } + + public void updateDocHead(int doc) throws IOException { + for (FieldComparator comparator : comparators) { + comparator.copy(0, doc); + comparator.setBottom(0); + } + this.doc = doc + readerContext.docBase; + } + } + } + + + // AbstractAllGroupHeadsCollector optimized for ord fields and scores. + static class OrdScoreAllGroupHeadsCollector extends TermAllGroupHeadsCollector { + + private final SentinelIntSet ordSet; + private final List collectedGroups; + private final SortField[] fields; + private final boolean sortContainsScore; + + private FieldCache.DocTermsIndex[] sortsIndex; + private Scorer scorer; + private GroupHead[] segmentGroupHeads; + + OrdScoreAllGroupHeadsCollector(String groupField, Sort sortWithinGroup, int initialSize) { + super(groupField, sortWithinGroup.getSort().length); + ordSet = new SentinelIntSet(initialSize, -1); + collectedGroups = new ArrayList(initialSize); + + final SortField[] sortFields = sortWithinGroup.getSort(); + fields = new SortField[sortFields.length]; + sortsIndex = new FieldCache.DocTermsIndex[sortFields.length]; + boolean sortContainsScore = false; + for (int i = 0; i < sortFields.length; i++) { + reversed[i] = sortFields[i].getReverse() ? -1 : 1; + fields[i] = sortFields[i]; + if (sortFields[i].getType() == SortField.Type.SCORE) { + sortContainsScore = true; + } + } + this.sortContainsScore = sortContainsScore; + } + + protected Collection getCollectedGroupHeads() { + return collectedGroups; + } + + public void setScorer(Scorer scorer) throws IOException { + this.scorer = scorer; + } + + protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException { + int key = groupIndex.getOrd(doc); + GroupHead groupHead; + if (!ordSet.exists(key)) { + ordSet.put(key); + BytesRef term = key == 0 ? 
null : groupIndex.getTerm(doc, new BytesRef()); + groupHead = new GroupHead(doc, term); + collectedGroups.add(groupHead); + segmentGroupHeads[key] = groupHead; + temporalResult.stop = true; + } else { + temporalResult.stop = false; + groupHead = segmentGroupHeads[key]; + } + temporalResult.groupHead = groupHead; + } + + public void setNextReader(IndexReader.AtomicReaderContext context) throws IOException { + this.readerContext = context; + groupIndex = FieldCache.DEFAULT.getTermsIndex(context.reader, groupField); + for (int i = 0; i < fields.length; i++) { + sortsIndex[i] = FieldCache.DEFAULT.getTermsIndex(context.reader, fields[i].getField()); + } + + // Clear ordSet and fill it with previous encountered groups that can occur in the current segment. + ordSet.clear(); + segmentGroupHeads = new GroupHead[groupIndex.numOrd()]; + for (GroupHead collectedGroup : collectedGroups) { + int ord = groupIndex.binarySearchLookup(collectedGroup.groupValue, scratchBytesRef); + if (ord >= 0) { + ordSet.put(ord); + segmentGroupHeads[ord] = collectedGroup; + + for (int i = 0; i < sortsIndex.length; i++) { + if (fields[i].getType() == SortField.Type.SCORE) { + continue; + } + + collectedGroup.sortOrds[i] = sortsIndex[i].binarySearchLookup(collectedGroup.sortValues[i], scratchBytesRef); + } + } + } + } + + class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead { + + BytesRef[] sortValues; + int[] sortOrds; + float[] scores; + + private GroupHead(int doc, BytesRef groupValue) throws IOException { + super(groupValue, doc + readerContext.docBase); + sortValues = new BytesRef[sortsIndex.length]; + sortOrds = new int[sortsIndex.length]; + if (sortContainsScore) { + scores = new float[sortsIndex.length]; + } + for (int i = 0; i < sortsIndex.length; i++) { + if (fields[i].getType() == SortField.Type.SCORE) { + scores[i] = scorer.score(); + } else { + sortValues[i] = sortsIndex[i].getTerm(doc, new BytesRef()); + sortOrds[i] = sortsIndex[i].getOrd(doc); + } + } + + } + + public int compare(int compIDX, int doc) throws IOException { + if (!sortContainsScore) { + return sortOrds[compIDX] - sortsIndex[compIDX].getOrd(doc); + } + + if (fields[compIDX].getType() == SortField.Type.SCORE) { + float score = scorer.score(); + if (scores[compIDX] < score) { + return -1; + } else if (scores[compIDX] > score) { + return 1; + } + return 0; + } else { + return sortOrds[compIDX] - sortsIndex[compIDX].getOrd(doc); + } + } + + public void updateDocHead(int doc) throws IOException { + for (int i = 0; i < sortsIndex.length; i++) { + if (fields[i].getType() == SortField.Type.SCORE) { + scores[i] = scorer.score(); + } else { + sortValues[i] = sortsIndex[i].getTerm(doc, new BytesRef()); + sortOrds[i] = sortsIndex[i].getOrd(doc); + } + } + this.doc = doc + readerContext.docBase; + } + + } + + } + + + // AbstractAllGroupHeadsCollector optimized for ord fields. 
+ static class OrdAllGroupHeadsCollector extends TermAllGroupHeadsCollector { + + private final SentinelIntSet ordSet; + private final List collectedGroups; + private final SortField[] fields; + + private FieldCache.DocTermsIndex[] sortsIndex; + private GroupHead[] segmentGroupHeads; + + OrdAllGroupHeadsCollector(String groupField, Sort sortWithinGroup, int initialSize) { + super(groupField, sortWithinGroup.getSort().length); + ordSet = new SentinelIntSet(initialSize, -1); + collectedGroups = new ArrayList(initialSize); + + final SortField[] sortFields = sortWithinGroup.getSort(); + fields = new SortField[sortFields.length]; + sortsIndex = new FieldCache.DocTermsIndex[sortFields.length]; + for (int i = 0; i < sortFields.length; i++) { + reversed[i] = sortFields[i].getReverse() ? -1 : 1; + fields[i] = sortFields[i]; + } + } + + protected Collection getCollectedGroupHeads() { + return collectedGroups; + } + + public void setScorer(Scorer scorer) throws IOException { + } + + protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException { + int key = groupIndex.getOrd(doc); + GroupHead groupHead; + if (!ordSet.exists(key)) { + ordSet.put(key); + BytesRef term = key == 0 ? null : groupIndex.getTerm(doc, new BytesRef()); + groupHead = new GroupHead(doc, term); + collectedGroups.add(groupHead); + segmentGroupHeads[key] = groupHead; + temporalResult.stop = true; + } else { + temporalResult.stop = false; + groupHead = segmentGroupHeads[key]; + } + temporalResult.groupHead = groupHead; + } + + public void setNextReader(IndexReader.AtomicReaderContext context) throws IOException { + this.readerContext = context; + groupIndex = FieldCache.DEFAULT.getTermsIndex(context.reader, groupField); + for (int i = 0; i < fields.length; i++) { + sortsIndex[i] = FieldCache.DEFAULT.getTermsIndex(context.reader, fields[i].getField()); + } + + // Clear ordSet and fill it with previous encountered groups that can occur in the current segment. + ordSet.clear(); + segmentGroupHeads = new GroupHead[groupIndex.numOrd()]; + for (GroupHead collectedGroup : collectedGroups) { + int ord = groupIndex.binarySearchLookup(collectedGroup.groupValue, scratchBytesRef); + if (ord >= 0) { + ordSet.put(ord); + segmentGroupHeads[ord] = collectedGroup; + + for (int i = 0; i < sortsIndex.length; i++) { + if (fields[i].getType() == SortField.Type.SCORE) { + continue; + } + + collectedGroup.sortOrds[i] = sortsIndex[i].binarySearchLookup(collectedGroup.sortValues[i], scratchBytesRef); + } + } + } + } + + class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead { + + BytesRef[] sortValues; + int[] sortOrds; + + private GroupHead(int doc, BytesRef groupValue) throws IOException { + super(groupValue, doc); + sortValues = new BytesRef[sortsIndex.length]; + sortOrds = new int[sortsIndex.length]; + for (int i = 0; i < sortsIndex.length; i++) { + sortValues[i] = sortsIndex[i].getTerm(doc, new BytesRef()); + sortOrds[i] = sortsIndex[i].getOrd(doc); + } + + this.doc = doc + readerContext.docBase; + } + + public int compare(int compIDX, int doc) throws IOException { + return sortOrds[compIDX] - sortsIndex[compIDX].getOrd(doc); + } + + public void updateDocHead(int doc) throws IOException { + for (int i = 0; i < sortsIndex.length; i++) { + sortValues[i] = sortsIndex[i].getTerm(doc, new BytesRef()); + sortOrds[i] = sortsIndex[i].getOrd(doc); + } + this.doc = doc + readerContext.docBase; + } + + } + + } + + + // AbstractAllGroupHeadsCollector optimized for scores. 
+ static class ScoreAllGroupHeadsCollector extends TermAllGroupHeadsCollector { + + private final SentinelIntSet ordSet; + private final List collectedGroups; + private final SortField[] fields; + + private Scorer scorer; + private GroupHead[] segmentGroupHeads; + + ScoreAllGroupHeadsCollector(String groupField, Sort sortWithinGroup, int initialSize) { + super(groupField, sortWithinGroup.getSort().length); + ordSet = new SentinelIntSet(initialSize, -1); + collectedGroups = new ArrayList(initialSize); + + final SortField[] sortFields = sortWithinGroup.getSort(); + fields = new SortField[sortFields.length]; + for (int i = 0; i < sortFields.length; i++) { + reversed[i] = sortFields[i].getReverse() ? -1 : 1; + fields[i] = sortFields[i]; + } + } + + protected Collection getCollectedGroupHeads() { + return collectedGroups; + } + + public void setScorer(Scorer scorer) throws IOException { + this.scorer = scorer; + } + + protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException { + int key = groupIndex.getOrd(doc); + GroupHead groupHead; + if (!ordSet.exists(key)) { + ordSet.put(key); + BytesRef term = key == 0 ? null : groupIndex.getTerm(doc, new BytesRef()); + groupHead = new GroupHead(doc, term); + collectedGroups.add(groupHead); + segmentGroupHeads[key] = groupHead; + temporalResult.stop = true; + } else { + temporalResult.stop = false; + groupHead = segmentGroupHeads[key]; + } + temporalResult.groupHead = groupHead; + } + + public void setNextReader(IndexReader.AtomicReaderContext context) throws IOException { + this.readerContext = context; + groupIndex = FieldCache.DEFAULT.getTermsIndex(context.reader, groupField); + + // Clear ordSet and fill it with previous encountered groups that can occur in the current segment. + ordSet.clear(); + segmentGroupHeads = new GroupHead[groupIndex.numOrd()]; + for (GroupHead collectedGroup : collectedGroups) { + int ord = groupIndex.binarySearchLookup(collectedGroup.groupValue, scratchBytesRef); + if (ord >= 0) { + ordSet.put(ord); + segmentGroupHeads[ord] = collectedGroup; + } + } + } + + class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead { + + float[] scores; + + private GroupHead(int doc, BytesRef groupValue) throws IOException { + super(groupValue, doc); + scores = new float[fields.length]; + float score = scorer.score(); + for (int i = 0; i < scores.length; i++) { + scores[i] = score; + } + + this.doc = doc + readerContext.docBase; + } + + public int compare(int compIDX, int doc) throws IOException { + float score = scorer.score(); + if (scores[compIDX] < score) { + return -1; + } else if (scores[compIDX] > score) { + return 1; + } + return 0; + } + + public void updateDocHead(int doc) throws IOException { + float score = scorer.score(); + for (int i = 0; i < scores.length; i++) { + scores[i] = score; + } + this.doc = doc + readerContext.docBase; + } + + } + + } + + +} Index: modules/grouping/src/java/org/apache/lucene/search/grouping/AbstractAllGroupHeadsCollector.java =================================================================== --- modules/grouping/src/java/org/apache/lucene/search/grouping/AbstractAllGroupHeadsCollector.java (revision ) +++ modules/grouping/src/java/org/apache/lucene/search/grouping/AbstractAllGroupHeadsCollector.java (revision ) @@ -0,0 +1,178 @@ +package org.apache.lucene.search.grouping; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.Collector; +import org.apache.lucene.util.OpenBitSet; + +import java.io.IOException; +import java.util.Collection; + +/** + * This collector specializes in collecting the most relevant document (group head) for each group that match the query. + * + * @lucene.experimental + */ +public abstract class AbstractAllGroupHeadsCollector extends Collector { + + protected final int[] reversed; + protected final int compIDXEnd; + protected final TemporalResult temporalResult; + + protected AbstractAllGroupHeadsCollector(int numberOfSorts) { + this.reversed = new int[numberOfSorts]; + this.compIDXEnd = numberOfSorts - 1; + temporalResult = new TemporalResult(); + } + + /** + * @param maxDoc The maxDoc of the top level {@link IndexReader}. + * @return an {@link OpenBitSet} containing all group heads. + */ + public OpenBitSet retrieveGroupHeads(int maxDoc) { + OpenBitSet bitSet = new OpenBitSet(maxDoc); + + Collection groupHeads = getCollectedGroupHeads(); + for (GroupHead groupHead : groupHeads) { + bitSet.fastSet(groupHead.doc); + } + + return bitSet; + } + + /** + * @return an int array containing all group heads. The size of the array is equal to number of collected unique groups. + */ + public int[] retrieveGroupHeads() { + Collection groupHeads = getCollectedGroupHeads(); + int[] docHeads = new int[groupHeads.size()]; + + int i = 0; + for (GroupHead groupHead : groupHeads) { + docHeads[i++] = groupHead.doc; + } + + return docHeads; + } + + /** + * @return the number of group heads found for a query. + */ + public int groupHeadsSize() { + return getCollectedGroupHeads().size(); + } + + /** + * Returns the group head and puts it into {@link #temporalResult}. + * If the group head wasn't encountered before then it will be added to the collected group heads. + *
+   * The {@link TemporalResult#stop} property will be true if the group head wasn't encountered before,
+   * otherwise false.
+   *
+   * @param doc The document to retrieve the group head for.
+   * @throws IOException If I/O related errors occur
+   */
+  protected abstract void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException;
+
+  /**
+   * Subsequent calls should return the same group heads.
+   *
+   * @return The collected group heads
+   */
+  protected abstract Collection<GH> getCollectedGroupHeads();
+
+  public void collect(int doc) throws IOException {
+    retrieveGroupHeadAndAddIfNotExist(doc);
+    if (temporalResult.stop) {
+      return;
+    }
+    GH groupHead = temporalResult.groupHead;
+
+    // Check whether the current doc is more relevant than the current group head for this group.
+    for (int compIDX = 0; ; compIDX++) {
+      final int c = reversed[compIDX] * groupHead.compare(compIDX, doc);
+      if (c < 0) {
+        // Definitely not competitive. So don't even bother to continue
+        return;
+      } else if (c > 0) {
+        // Definitely competitive.
+        break;
+      } else if (compIDX == compIDXEnd) {
+        // Here c=0. If we're at the last comparator, this doc is not
+        // competitive, since docs are visited in doc Id order, which means
+        // this doc cannot compete with any other document in the queue.
+        return;
+      }
+    }
+    groupHead.updateDocHead(doc);
+  }
+
+  public boolean acceptsDocsOutOfOrder() {
+    return true;
+  }
+
+  /**
+   * Contains the result of group head retrieval.
+   * Reused between collect calls, to prevent creating a new object for every collected document.
+   */
+  protected class TemporalResult {
+
+    protected GH groupHead;
+    protected boolean stop;
+
+  }
+
+  /**
+   * Represents a group head. A group head is the most relevant document for a particular group.
+   * The relevancy is usually based on the sort.
+   *
+   * The group head contains a group value with its associated most relevant document id.
+   */
+  public static abstract class GroupHead<GROUP_VALUE_TYPE> {
+
+    public final GROUP_VALUE_TYPE groupValue;
+    public int doc;
+
+    protected GroupHead(GROUP_VALUE_TYPE groupValue, int doc) {
+      this.groupValue = groupValue;
+      this.doc = doc;
+    }
+
+    /**
+     * Compares the specified document for a specified comparator against the current most relevant document.
+     *
+     * @param compIDX The comparator index of the specified comparator.
+     * @param doc The specified document.
+     * @return -1 if the specified document wasn't competitive against the current most relevant document, 1 if the
+     * specified document was competitive against the current most relevant document. Otherwise 0.
+     * @throws IOException If I/O related errors occur
+     */
+    protected abstract int compare(int compIDX, int doc) throws IOException;
+
+    /**
+     * Updates the current most relevant document with the specified document.
+ * + * @param doc The specified document + * @throws IOException If I/O related errors occur + */ + protected abstract void updateDocHead(int doc) throws IOException; + + } + +} Index: modules/grouping/src/test/org/apache/lucene/search/grouping/TestGrouping.java =================================================================== --- modules/grouping/src/test/org/apache/lucene/search/grouping/TestGrouping.java (revision 1145594) +++ modules/grouping/src/test/org/apache/lucene/search/grouping/TestGrouping.java (revision ) @@ -17,9 +17,6 @@ package org.apache.lucene.search.grouping; -import java.io.IOException; -import java.util.*; - import org.apache.lucene.analysis.MockAnalyzer; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; @@ -35,6 +32,9 @@ import org.apache.lucene.util.ReaderUtil; import org.apache.lucene.util._TestUtil; +import java.io.IOException; +import java.util.*; + // TODO // - should test relevance sort too // - test null @@ -649,18 +649,30 @@ final boolean doCache = random.nextBoolean(); final boolean doAllGroups = random.nextBoolean(); + final boolean doAllGroupHeads = random.nextBoolean(); if (VERBOSE) { System.out.println("TEST: groupSort=" + groupSort + " docSort=" + docSort + " searchTerm=" + searchTerm + " topNGroups=" + topNGroups + " groupOffset=" + groupOffset + " docOffset=" + docOffset + " doCache=" + doCache + " docsPerGroup=" + docsPerGroup + " doAllGroups=" + doAllGroups + " getScores=" + getScores + " getMaxScores=" + getMaxScores); } + List collectors = new ArrayList(); final TermAllGroupsCollector allGroupsCollector; if (doAllGroups) { allGroupsCollector = new TermAllGroupsCollector("group"); + collectors.add(allGroupsCollector); } else { allGroupsCollector = null; } + final AbstractAllGroupHeadsCollector allGroupHeadsCollector; + if (doAllGroupHeads) { + allGroupHeadsCollector = TermAllGroupHeadsCollector.create("group", docSort); + collectors.add(allGroupHeadsCollector); + } else { + allGroupHeadsCollector = null; + } + final TermFirstPassGroupingCollector c1 = new TermFirstPassGroupingCollector("group", groupSort, groupOffset+topNGroups); + collectors.add(c1); final CachingCollector cCache; final Collector c; @@ -673,24 +685,20 @@ } if (useWrappingCollector) { - if (doAllGroups) { - cCache = CachingCollector.create(c1, true, maxCacheMB); - c = MultiCollector.wrap(cCache, allGroupsCollector); + cCache = CachingCollector.create( + MultiCollector.wrap(collectors.toArray(new Collector[collectors.size()])), + true, + maxCacheMB + ); + c = cCache; - } else { + } else { - c = cCache = CachingCollector.create(c1, true, maxCacheMB); - } - } else { // Collect only into cache, then replay multiple times: c = cCache = CachingCollector.create(false, true, maxCacheMB); } } else { cCache = null; - if (doAllGroups) { - c = MultiCollector.wrap(c1, allGroupsCollector); - } else { - c = c1; + c = MultiCollector.wrap(collectors.toArray(new Collector[collectors.size()])); - } + } - } // Search top reader: final Query q = new TermQuery(new Term("content", searchTerm)); @@ -699,19 +707,12 @@ if (doCache && !useWrappingCollector) { if (cCache.isCached()) { // Replay for first-pass grouping - cCache.replay(c1); - if (doAllGroups) { - // Replay for all groups: - cCache.replay(allGroupsCollector); - } + cCache.replay(MultiCollector.wrap(collectors.toArray(new Collector[collectors.size()]))); } else { // Replay by re-running search: - s.search(new TermQuery(new Term("content", searchTerm)), c1); - if (doAllGroups) { - s.search(new TermQuery(new 
Term("content", searchTerm)), allGroupsCollector); + s.search(q, MultiCollector.wrap(collectors.toArray(new Collector[collectors.size()]))); - } - } + } + } - } final Collection> topGroups = c1.getTopGroups(groupOffset, fillFields); final TopGroups groupsResult; @@ -784,6 +785,40 @@ } assertEquals(docIDToID, expectedGroups, groupsResult, true, true, true, getScores); + if (doAllGroupHeads) { + int[] expectedGroupHeads = createExpectedGroupHeads(searchTerm, groupDocs, docSort); + int[] actualGroupHeads = allGroupHeadsCollector.retrieveGroupHeads(); + // The actual group heads contains Lucene ids. Need to change them into our id value. + for (int i = 0; i < actualGroupHeads.length; i++) { + actualGroupHeads[i] = docIDToID[actualGroupHeads[i]]; + } + // Allows us the easily iterate and assert the actual and expected results. + Arrays.sort(expectedGroupHeads); + Arrays.sort(actualGroupHeads); + + if (VERBOSE) { + System.out.println("Collector: " + allGroupHeadsCollector.getClass().getSimpleName()); + System.out.println("Sort: " + groupSort); + System.out.println("Sort within group: " + docSort); + System.out.println("Num group: " + numGroups); + System.out.println("Num doc: " + numDocs); + System.out.print("Expected: "); + for (int expectedDocId : expectedGroupHeads) { + System.out.print(expectedDocId + " "); + } + System.out.print("\nActual: "); + for (int actualDocId : actualGroupHeads) { + System.out.print(docIDToID[actualDocId] + " "); + } + System.out.println("\n================================================================================="); + } + + assertEquals(expectedGroupHeads.length, actualGroupHeads.length); + for (int i = 0; i < expectedGroupHeads.length; i++) { + assertEquals(expectedGroupHeads[i], actualGroupHeads[i]); + } + } + // Confirm merged shards match: assertEquals(docIDToID, expectedGroups, topGroupsShards, true, false, fillFields, getScores); if (topGroupsShards != null) { @@ -875,6 +910,33 @@ } } + private int[] createExpectedGroupHeads(String searchTerm, GroupDoc[] groupDocs, Sort docSort) throws IOException { + Map> groupHeads = new HashMap>(); + for (GroupDoc groupDoc : groupDocs) { + if (!groupDoc.content.startsWith(searchTerm)) { + continue; + } + + if (!groupHeads.containsKey(groupDoc.group)) { + List list = new ArrayList(); + list.add(groupDoc); + groupHeads.put(groupDoc.group, list); + continue; + } + groupHeads.get(groupDoc.group).add(groupDoc); + } + + int[] allGroupHeads = new int[groupHeads.size()]; + int i = 0; + for (List docs : groupHeads.values()) { + Collections.sort(docs, getComparator(docSort)); + + allGroupHeads[i++] = docs.get(0).id; + } + + return allGroupHeads; + } + private void verifyShards(int[] docStarts, TopGroups topGroups) { for(GroupDocs group : topGroups.groups) { for(int hitIDX=0;hitIDX