Index: modules/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java
===================================================================
--- modules/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java	(revision )
+++ modules/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java	(revision )
@@ -0,0 +1,699 @@
+package org.apache.lucene.search.grouping;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.search.*;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.OpenBitSet;
+
+import java.io.IOException;
+import java.util.*;
+
+/**
+ * This collector specializes in collecting the most relevant document (group head) for each group that matches the query.
+ *
+ * @lucene.experimental
+ */
+public abstract class AllGroupHeadsCollector extends Collector {
+
+  private static final int DEFAULT_INITIAL_SIZE = 128;
+
+  /**
+   * @param maxDoc The maxDoc of the top level {@link IndexReader}.
+   * @return an {@link OpenBitSet} containing all group heads.
+   */
+  public abstract OpenBitSet retrieveAllGroupHeads(int maxDoc);
+
+  /**
+   * @return an int array containing all group heads. The size of the array is equal to the number of collected unique groups.
+   */
+  public abstract int[] retrieveAllGroupHeads();
+
+  /**
+   * Creates an AllGroupHeadsCollector instance based on the supplied arguments.
+   * This factory method decides which implementation is best suited.
+   *
+   * @param groupField The field to group by
+   * @param sortWithinGroup The sort within each group
+   * @return an AllGroupHeadsCollector instance based on the supplied arguments
+   * @throws IOException If I/O related errors occur
+   */
+  public static AllGroupHeadsCollector create(String groupField, Sort sortWithinGroup) throws IOException {
+    return create(groupField, sortWithinGroup, DEFAULT_INITIAL_SIZE);
+  }
+
+  /**
+   * Creates an AllGroupHeadsCollector instance based on the supplied arguments.
+   * This factory method decides which implementation is best suited.
+   *
+   * @param groupField The field to group by
+   * @param sortWithinGroup The sort within each group
+   * @param initialSize The initial allocation size of the internal data structures, ideally a rough estimate of the number of unique groups
+   * @return an AllGroupHeadsCollector instance based on the supplied arguments
+   * @throws IOException If I/O related errors occur
+   */
+  public static AllGroupHeadsCollector create(String groupField, Sort sortWithinGroup, int initialSize) throws IOException {
+    boolean sortAllScore = true;
+    boolean sortAllFieldValue = true;
+
+    for (SortField sortField : sortWithinGroup.getSort()) {
+      if (sortField.getType() == SortField.SCORE) {
+        sortAllFieldValue = false;
+      } else if (sortField.getType() == SortField.CUSTOM || sortField.getType() == SortField.DOC) {
+        return new GeneralAllGroupHeadsCollector(groupField, sortWithinGroup);
+      } else {
+        sortAllScore = false;
+      }
+    }
+
+    if (sortAllScore) {
+      return new ScoreAllGroupHeadsCollector(groupField, sortWithinGroup, initialSize);
+    } else if (sortAllFieldValue) {
+      return new OrdAllGroupHeadsCollector(groupField, sortWithinGroup, initialSize);
+    } else {
+      return new OrdScoreAllGroupHeadsCollector(groupField, sortWithinGroup, initialSize);
+    }
+  }
+
+  // A general impl that works for any group sort.
+  static class GeneralAllGroupHeadsCollector extends AllGroupHeadsCollector {
+
+    private final String groupField;
+    private final Sort sortWithinGroup;
+    private final Map<BytesRef, GroupHead> groups;
+    private final BytesRef scratchBytesRef = new BytesRef();
+    private final int[] reversed;
+    private final int compIDXEnd;
+
+    private FieldCache.DocTermsIndex index;
+    private IndexReader.AtomicReaderContext readerContext;
+    private Scorer scorer;
+
+    GeneralAllGroupHeadsCollector(String groupField, Sort sortWithinGroup) throws IOException {
+      this.groupField = groupField;
+      this.sortWithinGroup = sortWithinGroup;
+      groups = new HashMap<BytesRef, GroupHead>();
+
+      final SortField[] sortFields = sortWithinGroup.getSort();
+      compIDXEnd = sortFields.length - 1;
+      reversed = new int[sortFields.length];
+      for (int i = 0; i < sortFields.length; i++) {
+        reversed[i] = sortFields[i].getReverse() ? -1 : 1;
+      }
+    }
+
+    public void collect(int doc) throws IOException {
+      final int ord = index.getOrd(doc);
+      final BytesRef groupValue = ord == 0 ? null : index.lookup(ord, scratchBytesRef);
+      GroupHead groupHead = groups.get(groupValue);
+      if (groupHead == null) {
+        groups.put(groupValue == null ? null : new BytesRef(groupValue), new GroupHead(sortWithinGroup, doc));
+        return;
+      }
+
+      // Ok now we need to check if the current doc is more relevant than the current head for this group
+      for (int compIDX = 0; ; compIDX++) {
+        final int c = reversed[compIDX] * groupHead.comparators[compIDX].compareBottom(doc);
+        if (c < 0) {
+          // Definitely not competitive. So don't even bother to continue
+          return;
+        } else if (c > 0) {
+          // Definitely competitive.
+          break;
+        } else if (compIDX == compIDXEnd) {
+          // Here c=0. If we're at the last comparator, this doc is not
+          // competitive, since docs are visited in doc Id order, which means
+          // this doc cannot compete with the current group head.
+          return;
+        }
+      }
+
+      groupHead.updateHead(doc);
+    }
+
+    public OpenBitSet retrieveAllGroupHeads(int maxDoc) {
+      OpenBitSet bitSet = new OpenBitSet(maxDoc);
+
+      for (GroupHead groupHead : groups.values()) {
+        bitSet.fastSet(groupHead.docId);
+      }
+
+      return bitSet;
+    }
+
+    public int[] retrieveAllGroupHeads() {
+      int[] docHeads = new int[groups.size()];
+
+      int i = 0;
+      for (GroupHead groupHead : groups.values()) {
+        docHeads[i++] = groupHead.docId;
+      }
+
+      return docHeads;
+    }
+
+    public void setNextReader(IndexReader.AtomicReaderContext context) throws IOException {
+      this.readerContext = context;
+      index = FieldCache.DEFAULT.getTermsIndex(context.reader, groupField);
+
+      for (GroupHead groupHead : groups.values()) {
+        for (int i = 0; i < groupHead.comparators.length; i++) {
+          groupHead.comparators[i] = groupHead.comparators[i].setNextReader(context);
+        }
+      }
+    }
+
+    public void setScorer(Scorer scorer) throws IOException {
+      this.scorer = scorer;
+      for (GroupHead groupHead : groups.values()) {
+        for (FieldComparator comparator : groupHead.comparators) {
+          comparator.setScorer(scorer);
+        }
+      }
+    }
+
+    public boolean acceptsDocsOutOfOrder() {
+      return false;
+    }
+
+    private class GroupHead {
+
+      final FieldComparator[] comparators;
+      int docId;
+
+      private GroupHead(Sort sort, int doc) throws IOException {
+        final SortField[] sortFields = sort.getSort();
+        comparators = new FieldComparator[sortFields.length];
+        for (int i = 0; i < sortFields.length; i++) {
+          final SortField sortField = sortFields[i];
+          comparators[i] = sortField.getComparator(1, i).setNextReader(readerContext);
+          comparators[i].setScorer(scorer);
+          comparators[i].copy(0, doc);
+          comparators[i].setBottom(0);
+        }
+
+        this.docId = doc + readerContext.docBase;
+      }
+
+      void updateHead(int doc) throws IOException {
+        for (FieldComparator comparator : comparators) {
+          comparator.copy(0, doc);
+          comparator.setBottom(0);
+        }
+        this.docId = doc + readerContext.docBase;
+      }
+    }
+  }
+
+  // AllGroupHeadsCollector optimized for ord fields and scores.
+  static class OrdScoreAllGroupHeadsCollector extends AllGroupHeadsCollector {
+
+    private final String groupField;
+    private final SentinelIntSet ordSet;
+    private final List<GroupHead> collectedGroups;
+    private final int[] reversed;
+    private final SortField[] fields;
+    private final int compIDXEnd;
+    private final boolean sortContainsScore;
+
+    private final BytesRef spareBytesRef = new BytesRef();
+    private FieldCache.DocTermsIndex groupIndex;
+    private FieldCache.DocTermsIndex[] sortsIndex;
+    private Scorer scorer;
+    private GroupHead[] segmentGroupHeads;
+    private IndexReader.AtomicReaderContext readerContext;
+
+    OrdScoreAllGroupHeadsCollector(String groupField, Sort sortWithinGroup, int initialSize) {
+      this.groupField = groupField;
+      ordSet = new SentinelIntSet(initialSize, -1);
+      collectedGroups = new ArrayList<GroupHead>(initialSize);
+
+      final SortField[] sortFields = sortWithinGroup.getSort();
+      compIDXEnd = sortFields.length - 1;
+      reversed = new int[sortFields.length];
+      fields = new SortField[sortFields.length];
+      sortsIndex = new FieldCache.DocTermsIndex[sortFields.length];
+      boolean sortContainsScore = false;
+      for (int i = 0; i < sortFields.length; i++) {
+        reversed[i] = sortFields[i].getReverse() ? -1 : 1;
+        fields[i] = sortFields[i];
+        if (sortFields[i].getType() == SortField.SCORE) {
+          sortContainsScore = true;
+        }
+      }
+      this.sortContainsScore = sortContainsScore;
+    }
+
+    public OpenBitSet retrieveAllGroupHeads(int maxDoc) {
+      OpenBitSet bitSet = new OpenBitSet(maxDoc);
+
+      for (GroupHead groupHead : collectedGroups) {
+        bitSet.fastSet(groupHead.docId);
+      }
+
+      return bitSet;
+    }
+
+    public int[] retrieveAllGroupHeads() {
+      int[] groupHeads = new int[collectedGroups.size()];
+
+      int i = 0;
+      for (GroupHead collectedGroup : collectedGroups) {
+        groupHeads[i++] = collectedGroup.docId;
+      }
+
+      return groupHeads;
+    }
+
+    public void setScorer(Scorer scorer) throws IOException {
+      this.scorer = scorer;
+    }
+
+    public void collect(int doc) throws IOException {
+      int key = groupIndex.getOrd(doc);
+      if (!ordSet.exists(key)) {
+        ordSet.put(key);
+        BytesRef term = key == 0 ? null : groupIndex.getTerm(doc, new BytesRef());
+        GroupHead groupHead = new GroupHead(doc, term);
+        collectedGroups.add(groupHead);
+        segmentGroupHeads[key] = groupHead;
+        return;
+      }
+
+      // Ok now we need to check if the current doc is more relevant than the current head for this group
+      GroupHead groupHead = segmentGroupHeads[key];
+      for (int compIDX = 0; ; compIDX++) {
+        final int c = reversed[compIDX] * compare(groupHead, doc, compIDX);
+        if (c < 0) {
+          // Definitely not competitive. So don't even bother to continue
+          return;
+        } else if (c > 0) {
+          // Definitely competitive.
+          break;
+        } else if (compIDX == compIDXEnd) {
+          // Here c=0. If we're at the last comparator, this doc is not
+          // competitive, since docs are visited in doc Id order, which means
+          // this doc cannot compete with the current group head.
+          return;
+        }
+      }
+
+      groupHead.updateHead(doc);
+    }
+
+    private int compare(GroupHead groupHead, int doc, int compIDX) throws IOException {
+      if (!sortContainsScore) {
+        return groupHead.sortOrds[compIDX] - sortsIndex[compIDX].getOrd(doc);
+      }
+
+      if (fields[compIDX].getType() == SortField.SCORE) {
+        float score = scorer.score();
+        if (groupHead.scores[compIDX] < score) {
+          return -1;
+        } else if (groupHead.scores[compIDX] > score) {
+          return 1;
+        }
+        return 0;
+      } else {
+        return groupHead.sortOrds[compIDX] - sortsIndex[compIDX].getOrd(doc);
+      }
+    }
+
+    public void setNextReader(IndexReader.AtomicReaderContext context) throws IOException {
+      this.readerContext = context;
+      groupIndex = FieldCache.DEFAULT.getTermsIndex(context.reader, groupField);
+      for (int i = 0; i < fields.length; i++) {
+        if (fields[i].getType() == SortField.SCORE) {
+          continue;
+        }
+
+        sortsIndex[i] = FieldCache.DEFAULT.getTermsIndex(context.reader, fields[i].getField());
+      }
+
+      // Clear ordSet and fill it with previously encountered groups that can occur in the current segment.
+      ordSet.clear();
+      segmentGroupHeads = new GroupHead[groupIndex.numOrd()];
+      for (GroupHead collectedGroup : collectedGroups) {
+        int ord = groupIndex.binarySearchLookup(collectedGroup.groupValue, spareBytesRef);
+        if (ord >= 0) {
+          ordSet.put(ord);
+          segmentGroupHeads[ord] = collectedGroup;
+
+          for (int i = 0; i < sortsIndex.length; i++) {
+            if (fields[i].getType() == SortField.SCORE) {
+              continue;
+            }
+
+            collectedGroup.sortOrds[i] = sortsIndex[i].binarySearchLookup(collectedGroup.sortValues[i], spareBytesRef);
+          }
+        }
+      }
+    }
+
+    public boolean acceptsDocsOutOfOrder() {
+      return false;
+    }
+
+    private class GroupHead {
+
+      BytesRef groupValue;
+      BytesRef[] sortValues;
+      int[] sortOrds;
+      float[] scores;
+      int docId;
+
+      private GroupHead(int doc, BytesRef groupValue) throws IOException {
+        this.groupValue = groupValue;
+        sortValues = new BytesRef[sortsIndex.length];
+        sortOrds = new int[sortsIndex.length];
+        if (sortContainsScore) {
+          scores = new float[sortsIndex.length];
+        }
+        for (int i = 0; i < sortsIndex.length; i++) {
+          if (fields[i].getType() == SortField.SCORE) {
+            scores[i] = scorer.score();
+          } else {
+            sortValues[i] = sortsIndex[i].getTerm(doc, new BytesRef());
+            sortOrds[i] = sortsIndex[i].getOrd(doc);
+          }
+        }
+
+        this.docId = doc + readerContext.docBase;
+      }
+
+      private void updateHead(int doc) throws IOException {
+        for (int i = 0; i < sortsIndex.length; i++) {
+          if (fields[i].getType() == SortField.SCORE) {
+            scores[i] = scorer.score();
+          } else {
+            sortValues[i] = sortsIndex[i].getTerm(doc, new BytesRef());
+            sortOrds[i] = sortsIndex[i].getOrd(doc);
+          }
+        }
+        this.docId = doc + readerContext.docBase;
+      }
+
+    }
+
+  }
+
+  // AllGroupHeadsCollector optimized for ord fields.
+  static class OrdAllGroupHeadsCollector extends AllGroupHeadsCollector {
+
+    private final String groupField;
+    private final SentinelIntSet ordSet;
+    private final List<GroupHead> collectedGroups;
+    private final int[] reversed;
+    private final SortField[] fields;
+    private final int compIDXEnd;
+
+    private final BytesRef spareBytesRef = new BytesRef();
+    private FieldCache.DocTermsIndex groupIndex;
+    private FieldCache.DocTermsIndex[] sortsIndex;
+    private GroupHead[] segmentGroupHeads;
+    private IndexReader.AtomicReaderContext readerContext;
+
+    OrdAllGroupHeadsCollector(String groupField, Sort sortWithinGroup, int initialSize) {
+      this.groupField = groupField;
+      ordSet = new SentinelIntSet(initialSize, -1);
+      collectedGroups = new ArrayList<GroupHead>(initialSize);
+
+      final SortField[] sortFields = sortWithinGroup.getSort();
+      compIDXEnd = sortFields.length - 1;
+      reversed = new int[sortFields.length];
+      fields = new SortField[sortFields.length];
+      sortsIndex = new FieldCache.DocTermsIndex[sortFields.length];
+      for (int i = 0; i < sortFields.length; i++) {
+        reversed[i] = sortFields[i].getReverse() ? -1 : 1;
+        fields[i] = sortFields[i];
+      }
+    }
+
+    public OpenBitSet retrieveAllGroupHeads(int maxDoc) {
+      OpenBitSet bitSet = new OpenBitSet(maxDoc);
+
+      for (GroupHead groupHead : collectedGroups) {
+        bitSet.fastSet(groupHead.docId);
+      }
+
+      return bitSet;
+    }
+
+    public int[] retrieveAllGroupHeads() {
+      int[] groupHeads = new int[collectedGroups.size()];
+
+      int i = 0;
+      for (GroupHead collectedGroup : collectedGroups) {
+        groupHeads[i++] = collectedGroup.docId;
+      }
+
+      return groupHeads;
+    }
+
+    public void setScorer(Scorer scorer) throws IOException {
+    }
+
+    public void collect(int doc) throws IOException {
+      int key = groupIndex.getOrd(doc);
+      if (!ordSet.exists(key)) {
+        ordSet.put(key);
+        BytesRef term = key == 0 ? null : groupIndex.getTerm(doc, new BytesRef());
+        GroupHead groupHead = new GroupHead(doc, term);
+        collectedGroups.add(groupHead);
+        segmentGroupHeads[key] = groupHead;
+        return;
+      }
+
+      // Ok now we need to check if the current doc is more relevant than the current head for this group
+      GroupHead groupHead = segmentGroupHeads[key];
+      for (int compIDX = 0; ; compIDX++) {
+        final int c = reversed[compIDX] * (groupHead.sortOrds[compIDX] - sortsIndex[compIDX].getOrd(doc));
+        if (c < 0) {
+          // Definitely not competitive. So don't even bother to continue
+          return;
+        } else if (c > 0) {
+          // Definitely competitive.
+          break;
+        } else if (compIDX == compIDXEnd) {
+          // Here c=0. If we're at the last comparator, this doc is not
+          // competitive, since docs are visited in doc Id order, which means
+          // this doc cannot compete with the current group head.
+          return;
+        }
+      }
+
+      groupHead.updateHead(doc);
+    }
+
+    public void setNextReader(IndexReader.AtomicReaderContext context) throws IOException {
+      this.readerContext = context;
+      groupIndex = FieldCache.DEFAULT.getTermsIndex(context.reader, groupField);
+      for (int i = 0; i < fields.length; i++) {
+        sortsIndex[i] = FieldCache.DEFAULT.getTermsIndex(context.reader, fields[i].getField());
+      }
+
+      // Clear ordSet and fill it with previously encountered groups that can occur in the current segment.
+      ordSet.clear();
+      segmentGroupHeads = new GroupHead[groupIndex.numOrd()];
+      for (GroupHead collectedGroup : collectedGroups) {
+        int ord = groupIndex.binarySearchLookup(collectedGroup.groupValue, spareBytesRef);
+        if (ord >= 0) {
+          ordSet.put(ord);
+          segmentGroupHeads[ord] = collectedGroup;
+
+          for (int i = 0; i < sortsIndex.length; i++) {
+            if (fields[i].getType() == SortField.SCORE) {
+              continue;
+            }
+
+            collectedGroup.sortOrds[i] = sortsIndex[i].binarySearchLookup(collectedGroup.sortValues[i], spareBytesRef);
+          }
+        }
+      }
+    }
+
+    public boolean acceptsDocsOutOfOrder() {
+      return false;
+    }
+
+    private class GroupHead {
+
+      BytesRef groupValue;
+      BytesRef[] sortValues;
+      int[] sortOrds;
+      int docId;
+
+      private GroupHead(int doc, BytesRef groupValue) throws IOException {
+        this.groupValue = groupValue;
+        sortValues = new BytesRef[sortsIndex.length];
+        sortOrds = new int[sortsIndex.length];
+        for (int i = 0; i < sortsIndex.length; i++) {
+          sortValues[i] = sortsIndex[i].getTerm(doc, new BytesRef());
+          sortOrds[i] = sortsIndex[i].getOrd(doc);
+        }
+
+        this.docId = doc + readerContext.docBase;
+      }
+
+      private void updateHead(int doc) throws IOException {
+        for (int i = 0; i < sortsIndex.length; i++) {
+          sortValues[i] = sortsIndex[i].getTerm(doc, new BytesRef());
+          sortOrds[i] = sortsIndex[i].getOrd(doc);
+        }
+        this.docId = doc + readerContext.docBase;
+      }
+
+    }
+
+  }
+
+  // AllGroupHeadsCollector optimized for scores.
+  static class ScoreAllGroupHeadsCollector extends AllGroupHeadsCollector {
+
+    private final String groupField;
+    private final SentinelIntSet ordSet;
+    private final List<GroupHead> collectedGroups;
+    private final int[] reversed;
+    private final SortField[] fields;
+    private final int compIDXEnd;
+
+    private final BytesRef spareBytesRef = new BytesRef();
+    private FieldCache.DocTermsIndex groupIndex;
+    private Scorer scorer;
+    private GroupHead[] segmentGroupHeads;
+    private IndexReader.AtomicReaderContext readerContext;
+
+    ScoreAllGroupHeadsCollector(String groupField, Sort sortWithinGroup, int initialSize) {
+      this.groupField = groupField;
+      ordSet = new SentinelIntSet(initialSize, -1);
+      collectedGroups = new ArrayList<GroupHead>(initialSize);
+
+      final SortField[] sortFields = sortWithinGroup.getSort();
+      compIDXEnd = sortFields.length - 1;
+      reversed = new int[sortFields.length];
+      fields = new SortField[sortFields.length];
+      for (int i = 0; i < sortFields.length; i++) {
+        reversed[i] = sortFields[i].getReverse() ? -1 : 1;
+        fields[i] = sortFields[i];
+      }
+    }
+
+    public OpenBitSet retrieveAllGroupHeads(int maxDoc) {
+      OpenBitSet bitSet = new OpenBitSet(maxDoc);
+
+      for (GroupHead groupHead : collectedGroups) {
+        bitSet.fastSet(groupHead.docId);
+      }
+
+      return bitSet;
+    }
+
+    public int[] retrieveAllGroupHeads() {
+      int[] groupHeads = new int[collectedGroups.size()];
+
+      int i = 0;
+      for (GroupHead collectedGroup : collectedGroups) {
+        groupHeads[i++] = collectedGroup.docId;
+      }
+
+      return groupHeads;
+    }
+
+    public void setScorer(Scorer scorer) throws IOException {
+      this.scorer = scorer;
+    }
+
+    public void collect(int doc) throws IOException {
+      int key = groupIndex.getOrd(doc);
+      if (!ordSet.exists(key)) {
+        ordSet.put(key);
+        BytesRef term = key == 0 ? null : groupIndex.getTerm(doc, new BytesRef());
+        GroupHead groupHead = new GroupHead(doc, term);
+        collectedGroups.add(groupHead);
+        segmentGroupHeads[key] = groupHead;
+        return;
+      }
+
+      // Ok now we need to check if the current doc is more relevant than the current head for this group
+      GroupHead groupHead = segmentGroupHeads[key];
+      float score = scorer.score();
+      for (int compIDX = 0; ; compIDX++) {
+        final int c = reversed[compIDX] * compare(groupHead, compIDX, score);
+        if (c < 0) {
+          // Definitely not competitive. So don't even bother to continue
+          return;
+        } else if (c > 0) {
+          // Definitely competitive.
+          break;
+        } else if (compIDX == compIDXEnd) {
+          // Here c=0. If we're at the last comparator, this doc is not
+          // competitive, since docs are visited in doc Id order, which means
+          // this doc cannot compete with the current group head.
+          return;
+        }
+      }
+
+      groupHead.updateHead(doc, score);
+    }
+
+    private int compare(GroupHead groupHead, int compIDX, float score) throws IOException {
+      if (groupHead.scores[compIDX] < score) {
+        return -1;
+      } else if (groupHead.scores[compIDX] > score) {
+        return 1;
+      }
+      return 0;
+    }
+
+    public void setNextReader(IndexReader.AtomicReaderContext context) throws IOException {
+      this.readerContext = context;
+      groupIndex = FieldCache.DEFAULT.getTermsIndex(context.reader, groupField);
+
+      // Clear ordSet and fill it with previously encountered groups that can occur in the current segment.
+      ordSet.clear();
+      segmentGroupHeads = new GroupHead[groupIndex.numOrd()];
+      for (GroupHead collectedGroup : collectedGroups) {
+        int ord = groupIndex.binarySearchLookup(collectedGroup.groupValue, spareBytesRef);
+        if (ord >= 0) {
+          ordSet.put(ord);
+          segmentGroupHeads[ord] = collectedGroup;
+        }
+      }
+    }
+
+    public boolean acceptsDocsOutOfOrder() {
+      return false;
+    }
+
+    private class GroupHead {
+
+      BytesRef groupValue;
+      float[] scores;
+      int docId;
+
+      private GroupHead(int doc, BytesRef groupValue) throws IOException {
+        this.groupValue = groupValue;
+        scores = new float[fields.length];
+        float score = scorer.score();
+        for (int i = 0; i < scores.length; i++) {
+          scores[i] = score;
+        }
+
+        this.docId = doc + readerContext.docBase;
+      }
+
+      private void updateHead(int doc, float score) throws IOException {
+        for (int i = 0; i < scores.length; i++) {
+          scores[i] = score;
+        }
+        this.docId = doc + readerContext.docBase;
+      }
+
+    }
+
+  }
+
+}
\ No newline at end of file
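
Usage note (not part of the patch): a minimal sketch of how the collector above could be driven from an IndexSearcher. The index path, the "author" group field and the "content"/"lucene" query are placeholders, and the snippet assumes the trunk APIs this patch is written against (IndexSearcher.search(Query, Collector), FSDirectory, SortField.FIELD_SCORE).

import java.io.File;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.grouping.AllGroupHeadsCollector;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.OpenBitSet;

public class AllGroupHeadsExample {

  public static void main(String[] args) throws Exception {
    IndexReader reader = IndexReader.open(FSDirectory.open(new File("/path/to/index")));
    IndexSearcher searcher = new IndexSearcher(reader);

    // One group head (most relevant document) per unique value of the "author" field;
    // sorting within the group on score only selects the score-optimized implementation.
    AllGroupHeadsCollector collector =
        AllGroupHeadsCollector.create("author", new Sort(SortField.FIELD_SCORE));
    searcher.search(new TermQuery(new Term("content", "lucene")), collector);

    // All group heads as a bit set over the top-level doc id space, e.g. usable as a filter.
    OpenBitSet groupHeads = collector.retrieveAllGroupHeads(reader.maxDoc());
    System.out.println("group heads: " + groupHeads.cardinality());

    searcher.close();
    reader.close();
  }
}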