Index: modules/grouping/src/test/org/apache/lucene/search/grouping/TermAllGroupHeadsCollectorTest.java =================================================================== --- modules/grouping/src/test/org/apache/lucene/search/grouping/TermAllGroupHeadsCollectorTest.java (revision ) +++ modules/grouping/src/test/org/apache/lucene/search/grouping/TermAllGroupHeadsCollectorTest.java (revision ) @@ -0,0 +1,147 @@ +package org.apache.lucene.search.grouping; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.LuceneTestCase; + +public class TermAllGroupHeadsCollectorTest extends LuceneTestCase { + + public void testRetrieveGroupHeads() throws Exception { + + final String groupField = "author"; + + Directory dir = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter( + random, + dir, + newIndexWriterConfig(TEST_VERSION_CURRENT, + new MockAnalyzer(random)).setMergePolicy(newLogMergePolicy())); + // 0 + Document doc = new Document(); + doc.add(new Field(groupField, "author1", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("content", "random text", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "1", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + + // 1 + doc = new Document(); + doc.add(new Field(groupField, "author1", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("content", "some more random text blob", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "2", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + + // 2 + doc = new Document(); + doc.add(new Field(groupField, "author1", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("content", "some more random textual data", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "3", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + w.commit(); // To ensure a second segment + + // 3 + doc = new Document(); + doc.add(new Field(groupField, "author2", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("content", "some random text", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "4", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + + // 4 + doc = new Document(); + doc.add(new Field(groupField, "author3", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("content", "some more random text", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "5", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + + // 5 + doc = new Document(); + doc.add(new Field(groupField, "author3", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("content", "random blob", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "6", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + + // 6 -- no author field + doc = new Document(); + doc.add(new Field("content", "random word stuck in alot of other text", Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("id", "6", Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS)); + w.addDocument(doc); + + IndexSearcher indexSearcher = new IndexSearcher(w.getReader()); + w.close(); + + Sort sortWithinGroup = new Sort(new SortField("id", SortField.Type.INT, true)); + AbstractAllGroupHeadsCollector c1 = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup); + indexSearcher.search(new TermQuery(new Term("content", "random")), c1); + assertTrue(arrayContains(new int[]{2, 3, 5, 6}, c1.retrieveGroupHeads())); + + AbstractAllGroupHeadsCollector c2 = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup); + indexSearcher.search(new TermQuery(new Term("content", "some")), c2); + assertTrue(arrayContains(new int[]{2, 3, 4}, c2.retrieveGroupHeads())); + + AbstractAllGroupHeadsCollector c3 = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup); + indexSearcher.search(new TermQuery(new Term("content", "blob")), c3); + assertTrue(arrayContains(new int[]{1, 5}, c3.retrieveGroupHeads())); + + // STRING sort type triggers different implementation + Sort sortWithinGroup2 = new Sort(new SortField("id", SortField.Type.STRING, true)); + AbstractAllGroupHeadsCollector c4 = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup2); + indexSearcher.search(new TermQuery(new Term("content", "random")), c4); + assertTrue(arrayContains(new int[]{2, 3, 5, 6}, c4.retrieveGroupHeads())); + + Sort sortWithinGroup3 = new Sort(new SortField("id", SortField.Type.STRING, false)); + AbstractAllGroupHeadsCollector c5 = TermAllGroupHeadsCollector.create(groupField, sortWithinGroup3); + indexSearcher.search(new TermQuery(new Term("content", "random")), c5); + // 6 b/c higher doc id wins, even if order of field is in not in reverse. + assertTrue(arrayContains(new int[]{0, 3, 4, 6}, c5.retrieveGroupHeads())); + + indexSearcher.getIndexReader().close(); + dir.close(); + } + + + private boolean arrayContains(int[] expected, int[] result) { + if (expected.length != result.length) { + return false; + } + + for (int e : expected) { + boolean found = false; + for (int r : result) { + if (e == r) { + found = true; + } + } + + if (!found) { + return false; + } + } + + return true; + } + +} Index: modules/grouping/src/java/org/apache/lucene/search/grouping/TermAllGroupHeadsCollector.java =================================================================== --- modules/grouping/src/java/org/apache/lucene/search/grouping/TermAllGroupHeadsCollector.java (revision ) +++ modules/grouping/src/java/org/apache/lucene/search/grouping/TermAllGroupHeadsCollector.java (revision ) @@ -0,0 +1,531 @@ +package org.apache.lucene.search.grouping; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.*; +import org.apache.lucene.util.BytesRef; + +import java.io.IOException; +import java.util.*; + +/** + * + */ +public abstract class TermAllGroupHeadsCollector extends AbstractAllGroupHeadsCollector { + + private static final int DEFAULT_INITIAL_SIZE = 128; + + final String groupField; + final BytesRef scratchBytesRef = new BytesRef(); + + FieldCache.DocTermsIndex groupIndex; + IndexReader.AtomicReaderContext readerContext; + + protected TermAllGroupHeadsCollector(String groupField, int numberOfSorts) { + super(numberOfSorts); + this.groupField = groupField; + } + + /** + * Creates an AbstractAllGroupHeadsCollector instance based on the supplied arguments. + * This factory method decides with implementation is best suited. + * + * @param groupField The field to group by + * @param sortWithinGroup The sort within each group + * @return an AbstractAllGroupHeadsCollector instance based on the supplied arguments + * @throws IOException If I/O related errors occur + */ + public static AbstractAllGroupHeadsCollector create(String groupField, Sort sortWithinGroup) throws IOException { + return create(groupField, sortWithinGroup, DEFAULT_INITIAL_SIZE); + } + + public static AbstractAllGroupHeadsCollector create(String groupField, Sort sortWithinGroup, int initialSize) throws IOException { + boolean sortAllScore = true; + boolean sortAllFieldValue = true; + + for (SortField sortField : sortWithinGroup.getSort()) { + if (sortField.getType() == SortField.Type.SCORE) { + sortAllFieldValue = false; + } else if (needGeneralImpl(sortField)) { + return new GeneralAllGroupHeadsCollector(groupField, sortWithinGroup); + } else { + sortAllScore = false; + } + } + + if (sortAllScore) { + return new ScoreAllGroupHeadsCollector(groupField, sortWithinGroup, initialSize); + } else if (sortAllFieldValue) { + return new OrdAllGroupHeadsCollector(groupField, sortWithinGroup, initialSize); + } else { + return new OrdScoreAllGroupHeadsCollector(groupField, sortWithinGroup, initialSize); + } + } + + // Returns when a sort field needs the general impl. + private static boolean needGeneralImpl(SortField sortField) { + SortField.Type sortType = sortField.getType(); + // Note (MvG): We can also make an optimized impl when sorting is SortField.DOC + return sortType != SortField.Type.STRING_VAL && sortType != SortField.Type.STRING && sortType != SortField.Type.SCORE; + } + + // A general impl that works for any group sort. + static class GeneralAllGroupHeadsCollector extends TermAllGroupHeadsCollector { + + private final Sort sortWithinGroup; + private final Map groups; + + private Scorer scorer; + + GeneralAllGroupHeadsCollector(String groupField, Sort sortWithinGroup) throws IOException { + super(groupField, sortWithinGroup.getSort().length); + this.sortWithinGroup = sortWithinGroup; + groups = new HashMap(); + + final SortField[] sortFields = sortWithinGroup.getSort(); + for (int i = 0; i < sortFields.length; i++) { + reversed[i] = sortFields[i].getReverse() ? -1 : 1; + } + } + + protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException { + final int ord = groupIndex.getOrd(doc); + final BytesRef groupValue = ord == 0 ? null : groupIndex.lookup(ord, scratchBytesRef); + GroupHead groupHead = groups.get(groupValue); + if (groupHead == null) { + groupHead = new GroupHead(groupValue, sortWithinGroup, doc); + groups.put(groupValue == null ? null : new BytesRef(groupValue), groupHead); + temporalResult.stop = true; + } else { + temporalResult.stop = false; + } + temporalResult.groupHead = groupHead; + } + + protected Collection getCollectedGroupHeads() { + return groups.values(); + } + + public void setNextReader(IndexReader.AtomicReaderContext context) throws IOException { + this.readerContext = context; + groupIndex = FieldCache.DEFAULT.getTermsIndex(context.reader, groupField); + + for (GroupHead groupHead : groups.values()) { + for (int i = 0; i < groupHead.comparators.length; i++) { + groupHead.comparators[i] = groupHead.comparators[i].setNextReader(context); + } + } + } + + public void setScorer(Scorer scorer) throws IOException { + this.scorer = scorer; + for (GroupHead groupHead : groups.values()) { + for (FieldComparator comparator : groupHead.comparators) { + comparator.setScorer(scorer); + } + } + } + + class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead { + + final FieldComparator[] comparators; + + private GroupHead(BytesRef groupValue, Sort sort, int doc) throws IOException { + super(groupValue, doc + readerContext.docBase); + final SortField[] sortFields = sort.getSort(); + comparators = new FieldComparator[sortFields.length]; + for (int i = 0; i < sortFields.length; i++) { + final SortField sortField = sortFields[i]; + comparators[i] = sortField.getComparator(1, i); + comparators[i].setNextReader(readerContext); + comparators[i].setScorer(scorer); + comparators[i].copy(0, doc); + comparators[i].setBottom(0); + } + } + + public int compare(int compIDX, int doc) throws IOException { + return comparators[compIDX].compareBottom(doc); + } + + public void updateDocHead(int doc) throws IOException { + for (FieldComparator comparator : comparators) { + comparator.copy(0, doc); + comparator.setBottom(0); + } + this.doc = doc + readerContext.docBase; + } + } + } + + + // AbstractAllGroupHeadsCollector optimized for ord fields and scores. + static class OrdScoreAllGroupHeadsCollector extends TermAllGroupHeadsCollector { + + private final SentinelIntSet ordSet; + private final List collectedGroups; + private final SortField[] fields; + private final boolean sortContainsScore; + + private FieldCache.DocTermsIndex[] sortsIndex; + private Scorer scorer; + private GroupHead[] segmentGroupHeads; + + OrdScoreAllGroupHeadsCollector(String groupField, Sort sortWithinGroup, int initialSize) { + super(groupField, sortWithinGroup.getSort().length); + ordSet = new SentinelIntSet(initialSize, -1); + collectedGroups = new ArrayList(initialSize); + + final SortField[] sortFields = sortWithinGroup.getSort(); + fields = new SortField[sortFields.length]; + sortsIndex = new FieldCache.DocTermsIndex[sortFields.length]; + boolean sortContainsScore = false; + for (int i = 0; i < sortFields.length; i++) { + reversed[i] = sortFields[i].getReverse() ? -1 : 1; + fields[i] = sortFields[i]; + if (sortFields[i].getType() == SortField.Type.SCORE) { + sortContainsScore = true; + } + } + this.sortContainsScore = sortContainsScore; + } + + protected Collection getCollectedGroupHeads() { + return collectedGroups; + } + + public void setScorer(Scorer scorer) throws IOException { + this.scorer = scorer; + } + + protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException { + int key = groupIndex.getOrd(doc); + GroupHead groupHead; + if (!ordSet.exists(key)) { + ordSet.put(key); + BytesRef term = key == 0 ? null : groupIndex.getTerm(doc, new BytesRef()); + groupHead = new GroupHead(doc, term); + collectedGroups.add(groupHead); + segmentGroupHeads[key] = groupHead; + temporalResult.stop = true; + } else { + temporalResult.stop = false; + groupHead = segmentGroupHeads[key]; + } + temporalResult.groupHead = groupHead; + } + + public void setNextReader(IndexReader.AtomicReaderContext context) throws IOException { + this.readerContext = context; + groupIndex = FieldCache.DEFAULT.getTermsIndex(context.reader, groupField); + for (int i = 0; i < fields.length; i++) { + sortsIndex[i] = FieldCache.DEFAULT.getTermsIndex(context.reader, fields[i].getField()); + } + + // Clear ordSet and fill it with previous encountered groups that can occur in the current segment. + ordSet.clear(); + segmentGroupHeads = new GroupHead[groupIndex.numOrd()]; + for (GroupHead collectedGroup : collectedGroups) { + int ord = groupIndex.binarySearchLookup(collectedGroup.groupValue, scratchBytesRef); + if (ord >= 0) { + ordSet.put(ord); + segmentGroupHeads[ord] = collectedGroup; + + for (int i = 0; i < sortsIndex.length; i++) { + if (fields[i].getType() == SortField.Type.SCORE) { + continue; + } + + collectedGroup.sortOrds[i] = sortsIndex[i].binarySearchLookup(collectedGroup.sortValues[i], scratchBytesRef); + } + } + } + } + + class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead { + + BytesRef[] sortValues; + int[] sortOrds; + float[] scores; + + private GroupHead(int doc, BytesRef groupValue) throws IOException { + super(groupValue, doc + readerContext.docBase); + sortValues = new BytesRef[sortsIndex.length]; + sortOrds = new int[sortsIndex.length]; + if (sortContainsScore) { + scores = new float[sortsIndex.length]; + } + for (int i = 0; i < sortsIndex.length; i++) { + if (fields[i].getType() == SortField.Type.SCORE) { + scores[i] = scorer.score(); + } else { + sortValues[i] = sortsIndex[i].getTerm(doc, new BytesRef()); + sortOrds[i] = sortsIndex[i].getOrd(doc); + } + } + + } + + public int compare(int compIDX, int doc) throws IOException { + if (!sortContainsScore) { + return sortOrds[compIDX] - sortsIndex[compIDX].getOrd(doc); + } + + if (fields[compIDX].getType() == SortField.Type.SCORE) { + float score = scorer.score(); + if (scores[compIDX] < score) { + return -1; + } else if (scores[compIDX] > score) { + return 1; + } + return 0; + } else { + return sortOrds[compIDX] - sortsIndex[compIDX].getOrd(doc); + } + } + + public void updateDocHead(int doc) throws IOException { + for (int i = 0; i < sortsIndex.length; i++) { + if (fields[i].getType() == SortField.Type.SCORE) { + scores[i] = scorer.score(); + } else { + sortValues[i] = sortsIndex[i].getTerm(doc, new BytesRef()); + sortOrds[i] = sortsIndex[i].getOrd(doc); + } + } + this.doc = doc + readerContext.docBase; + } + + } + + } + + + // AbstractAllGroupHeadsCollector optimized for ord fields. + static class OrdAllGroupHeadsCollector extends TermAllGroupHeadsCollector { + + private final SentinelIntSet ordSet; + private final List collectedGroups; + private final SortField[] fields; + + private FieldCache.DocTermsIndex[] sortsIndex; + private GroupHead[] segmentGroupHeads; + + OrdAllGroupHeadsCollector(String groupField, Sort sortWithinGroup, int initialSize) { + super(groupField, sortWithinGroup.getSort().length); + ordSet = new SentinelIntSet(initialSize, -1); + collectedGroups = new ArrayList(initialSize); + + final SortField[] sortFields = sortWithinGroup.getSort(); + fields = new SortField[sortFields.length]; + sortsIndex = new FieldCache.DocTermsIndex[sortFields.length]; + for (int i = 0; i < sortFields.length; i++) { + reversed[i] = sortFields[i].getReverse() ? -1 : 1; + fields[i] = sortFields[i]; + } + } + + protected Collection getCollectedGroupHeads() { + return collectedGroups; + } + + public void setScorer(Scorer scorer) throws IOException { + } + + protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException { + int key = groupIndex.getOrd(doc); + GroupHead groupHead; + if (!ordSet.exists(key)) { + ordSet.put(key); + BytesRef term = key == 0 ? null : groupIndex.getTerm(doc, new BytesRef()); + groupHead = new GroupHead(doc, term); + collectedGroups.add(groupHead); + segmentGroupHeads[key] = groupHead; + temporalResult.stop = true; + } else { + temporalResult.stop = false; + groupHead = segmentGroupHeads[key]; + } + temporalResult.groupHead = groupHead; + } + + public void setNextReader(IndexReader.AtomicReaderContext context) throws IOException { + this.readerContext = context; + groupIndex = FieldCache.DEFAULT.getTermsIndex(context.reader, groupField); + for (int i = 0; i < fields.length; i++) { + sortsIndex[i] = FieldCache.DEFAULT.getTermsIndex(context.reader, fields[i].getField()); + } + + // Clear ordSet and fill it with previous encountered groups that can occur in the current segment. + ordSet.clear(); + segmentGroupHeads = new GroupHead[groupIndex.numOrd()]; + for (GroupHead collectedGroup : collectedGroups) { + int ord = groupIndex.binarySearchLookup(collectedGroup.groupValue, scratchBytesRef); + if (ord >= 0) { + ordSet.put(ord); + segmentGroupHeads[ord] = collectedGroup; + + for (int i = 0; i < sortsIndex.length; i++) { + if (fields[i].getType() == SortField.Type.SCORE) { + continue; + } + + collectedGroup.sortOrds[i] = sortsIndex[i].binarySearchLookup(collectedGroup.sortValues[i], scratchBytesRef); + } + } + } + } + + class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead { + + BytesRef[] sortValues; + int[] sortOrds; + + private GroupHead(int doc, BytesRef groupValue) throws IOException { + super(groupValue, doc); + sortValues = new BytesRef[sortsIndex.length]; + sortOrds = new int[sortsIndex.length]; + for (int i = 0; i < sortsIndex.length; i++) { + sortValues[i] = sortsIndex[i].getTerm(doc, new BytesRef()); + sortOrds[i] = sortsIndex[i].getOrd(doc); + } + + this.doc = doc + readerContext.docBase; + } + + public int compare(int compIDX, int doc) throws IOException { + return sortOrds[compIDX] - sortsIndex[compIDX].getOrd(doc); + } + + public void updateDocHead(int doc) throws IOException { + for (int i = 0; i < sortsIndex.length; i++) { + sortValues[i] = sortsIndex[i].getTerm(doc, new BytesRef()); + sortOrds[i] = sortsIndex[i].getOrd(doc); + } + this.doc = doc + readerContext.docBase; + } + + } + + } + + + // AbstractAllGroupHeadsCollector optimized for scores. + static class ScoreAllGroupHeadsCollector extends TermAllGroupHeadsCollector { + + private final SentinelIntSet ordSet; + private final List collectedGroups; + private final SortField[] fields; + + private Scorer scorer; + private GroupHead[] segmentGroupHeads; + + ScoreAllGroupHeadsCollector(String groupField, Sort sortWithinGroup, int initialSize) { + super(groupField, sortWithinGroup.getSort().length); + ordSet = new SentinelIntSet(initialSize, -1); + collectedGroups = new ArrayList(initialSize); + + final SortField[] sortFields = sortWithinGroup.getSort(); + fields = new SortField[sortFields.length]; + for (int i = 0; i < sortFields.length; i++) { + reversed[i] = sortFields[i].getReverse() ? -1 : 1; + fields[i] = sortFields[i]; + } + } + + protected Collection getCollectedGroupHeads() { + return collectedGroups; + } + + public void setScorer(Scorer scorer) throws IOException { + this.scorer = scorer; + } + + protected void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException { + int key = groupIndex.getOrd(doc); + GroupHead groupHead; + if (!ordSet.exists(key)) { + ordSet.put(key); + BytesRef term = key == 0 ? null : groupIndex.getTerm(doc, new BytesRef()); + groupHead = new GroupHead(doc, term); + collectedGroups.add(groupHead); + segmentGroupHeads[key] = groupHead; + temporalResult.stop = true; + } else { + temporalResult.stop = false; + groupHead = segmentGroupHeads[key]; + } + temporalResult.groupHead = groupHead; + } + + public void setNextReader(IndexReader.AtomicReaderContext context) throws IOException { + this.readerContext = context; + groupIndex = FieldCache.DEFAULT.getTermsIndex(context.reader, groupField); + + // Clear ordSet and fill it with previous encountered groups that can occur in the current segment. + ordSet.clear(); + segmentGroupHeads = new GroupHead[groupIndex.numOrd()]; + for (GroupHead collectedGroup : collectedGroups) { + int ord = groupIndex.binarySearchLookup(collectedGroup.groupValue, scratchBytesRef); + if (ord >= 0) { + ordSet.put(ord); + segmentGroupHeads[ord] = collectedGroup; + } + } + } + + class GroupHead extends AbstractAllGroupHeadsCollector.GroupHead { + + float[] scores; + + private GroupHead(int doc, BytesRef groupValue) throws IOException { + super(groupValue, doc); + scores = new float[fields.length]; + float score = scorer.score(); + for (int i = 0; i < scores.length; i++) { + scores[i] = score; + } + + this.doc = doc + readerContext.docBase; + } + + public int compare(int compIDX, int doc) throws IOException { + float score = scorer.score(); + if (scores[compIDX] < score) { + return -1; + } else if (scores[compIDX] > score) { + return 1; + } + return 0; + } + + public void updateDocHead(int doc) throws IOException { + float score = scorer.score(); + for (int i = 0; i < scores.length; i++) { + scores[i] = score; + } + this.doc = doc + readerContext.docBase; + } + + } + + } + + +} Index: modules/grouping/src/java/org/apache/lucene/search/grouping/AbstractAllGroupHeadsCollector.java =================================================================== --- modules/grouping/src/java/org/apache/lucene/search/grouping/AbstractAllGroupHeadsCollector.java (revision ) +++ modules/grouping/src/java/org/apache/lucene/search/grouping/AbstractAllGroupHeadsCollector.java (revision ) @@ -0,0 +1,176 @@ +package org.apache.lucene.search.grouping; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.Collector; +import org.apache.lucene.util.OpenBitSet; + +import java.io.IOException; +import java.util.Collection; + +/** + * This collector specializes in collecting the most relevant document (group head) for each group that match the query. + * + * @lucene.experimental + */ +public abstract class AbstractAllGroupHeadsCollector extends Collector { + + protected final int[] reversed; + protected final int compIDXEnd; + protected final TemporalResult temporalResult; + + protected AbstractAllGroupHeadsCollector(int numberOfSorts) { + this.reversed = new int[numberOfSorts]; + this.compIDXEnd = numberOfSorts - 1; + temporalResult = new TemporalResult(); + } + + /** + * @param maxDoc The maxDoc of the top level {@link IndexReader}. + * @return an {@link OpenBitSet} containing all group heads. + */ + public OpenBitSet retrieveGroupHeads(int maxDoc) { + OpenBitSet bitSet = new OpenBitSet(maxDoc); + + Collection groupHeads = getCollectedGroupHeads(); + for (GroupHead groupHead : groupHeads) { + bitSet.fastSet(groupHead.doc); + } + + return bitSet; + } + + /** + * @return an int array containing all group heads. The size of the array is equal to number of collected unique groups. + */ + public int[] retrieveGroupHeads() { + Collection groupHeads = getCollectedGroupHeads(); + int[] docHeads = new int[groupHeads.size()]; + + int i = 0; + for (GroupHead groupHead : groupHeads) { + docHeads[i++] = groupHead.doc; + } + + return docHeads; + } + + /** + * @return the number of group heads found for a query. + */ + public int groupHeadsSize() { + return getCollectedGroupHeads().size(); + } + + /** + * Returns the group head and puts it into {@link #temporalResult}. + * If the group head wasn't encountered before then it will be added to the collected group heads. + *

+ * The {@link TemporalResult#stop} property will be true if the group head wasn't encountered before + * otherwise false. + * + * @param doc The document to retrieve the group head for. + * @throws IOException If I/O related errors occur + */ + protected abstract void retrieveGroupHeadAndAddIfNotExist(int doc) throws IOException; + + /** + * @return The collected group heads + */ + protected abstract Collection getCollectedGroupHeads(); + + public void collect(int doc) throws IOException { + retrieveGroupHeadAndAddIfNotExist(doc); + if (temporalResult.stop) { + return; + } + GH groupHead = temporalResult.groupHead; + + // Ok now we need to check if the current doc is more relevant then current doc for this group + for (int compIDX = 0; ; compIDX++) { + final int c = reversed[compIDX] * groupHead.compare(compIDX, doc); + if (c < 0) { + // Definitely not competitive. So don't even bother to continue + return; + } else if (c > 0) { + // Definitely competitive. + break; + } else if (compIDX == compIDXEnd) { + // Here c=0. If we're at the last comparator, this doc is not + // competitive, since docs are visited in doc Id order, which means + // this doc cannot compete with any other document in the queue. + return; + } + } + groupHead.updateDocHead(doc); + } + + public boolean acceptsDocsOutOfOrder() { + return true; + } + + /** + * Contains the result of group head retrieval. + * To prevent new object creations of this class for every collect. + */ + protected class TemporalResult { + + protected GH groupHead; + protected boolean stop; + + } + + /** + * Represents a group head. A group head is the most relevant document for a particular group. + * The relevancy is based is usually based on the sort. + * + * The group head contains a group value with its associated most relevant document id. + */ + public static abstract class GroupHead { + + public final GROUP_VALUE_TYPE groupValue; + public int doc; + + protected GroupHead(GROUP_VALUE_TYPE groupValue, int doc) { + this.groupValue = groupValue; + this.doc = doc; + } + + /** + * Compares the specified document for a specified comparator against the current most relevant document. + * + * @param compIDX The comparator index of the specified comparator. + * @param doc The specified document. + * @return -1 if the specified document wasn't competitive against the current most relevant document, 1 if the + * specified document was competitive against the current most relevant document. Otherwise 0. + * @throws IOException If I/O related errors occur + */ + protected abstract int compare(int compIDX, int doc) throws IOException; + + /** + * Updates the current most relevant document with the specified document. + * + * @param doc The specified document + * @throws IOException If I/O related errors occur + */ + protected abstract void updateDocHead(int doc) throws IOException; + + } + +}