Index: modules/grouping/src/test/org/apache/lucene/search/grouping/TermGroupFacetCollectorTest.java IDEA additional info: Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP <+>UTF-8 =================================================================== --- modules/grouping/src/test/org/apache/lucene/search/grouping/TermGroupFacetCollectorTest.java (revision ) +++ modules/grouping/src/test/org/apache/lucene/search/grouping/TermGroupFacetCollectorTest.java (revision ) @@ -0,0 +1,228 @@ +package org.apache.lucene.search.grouping; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.document.*; +import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.grouping.term.TermGroupFacetCollector; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.LuceneTestCase; + +import java.util.List; + +/** + * @author Martijn van Groningen + */ +public class TermGroupFacetCollectorTest extends LuceneTestCase { + + public void testSimple() throws Exception { + final String groupField = "hotel"; + FieldType customType = new FieldType(); + customType.setStored(true); + + Directory dir = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter( + random, + dir, + newIndexWriterConfig(TEST_VERSION_CURRENT, + new MockAnalyzer(random)).setMergePolicy(newLogMergePolicy())); + boolean canUseIDV = false;// Enable later... !"Lucene3x".equals(w.w.getConfig().getCodec().getName()); + + // 0 + Document doc = new Document(); + addGroupField(doc, groupField, "a", canUseIDV); + doc.add(new Field("airport", "ams", TextField.TYPE_UNSTORED)); + doc.add(new Field("duration", "5", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + // 1 + doc = new Document(); + addGroupField(doc, groupField, "a", canUseIDV); + doc.add(new Field("airport", "dus", TextField.TYPE_STORED)); + doc.add(new Field("duration", "10", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + // 2 + doc = new Document(); + addGroupField(doc, groupField, "b", canUseIDV); + doc.add(new Field("airport", "ams", TextField.TYPE_UNSTORED)); + doc.add(new Field("duration", "10", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + w.commit(); // To ensure a second segment + + // 3 + doc = new Document(); + addGroupField(doc, groupField, "b", canUseIDV); + doc.add(new Field("airport", "ams", TextField.TYPE_UNSTORED)); + doc.add(new Field("duration", "5", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + // 4 + doc = new Document(); + addGroupField(doc, groupField, "b", canUseIDV); + doc.add(new Field("airport", "ams", TextField.TYPE_UNSTORED)); + doc.add(new Field("duration", "5", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + IndexSearcher indexSearcher = new IndexSearcher(w.getReader()); + + + TermGroupFacetCollector groupedAirportFacetCollector = new TermGroupFacetCollector(groupField, "airport", null, 128); + indexSearcher.search(new MatchAllDocsQuery(), groupedAirportFacetCollector); + TermGroupFacetCollector.GroupedFacetResult airportResult = groupedAirportFacetCollector.mergeSegmentResults(10, 0, false); + assertEquals(3, airportResult.getTotalCount()); + assertEquals(0, airportResult.getTotalMissingCount()); + + List entries = airportResult.getFacetEntries(0, 10); + assertEquals(2, entries.size()); + assertEquals("ams", entries.get(0).getValue().utf8ToString()); + assertEquals(2, entries.get(0).getCount()); + assertEquals("dus", entries.get(1).getValue().utf8ToString()); + assertEquals(1, entries.get(1).getCount()); + + + TermGroupFacetCollector groupedDurationFacetCollector = new TermGroupFacetCollector(groupField, "duration", null, 128); + indexSearcher.search(new MatchAllDocsQuery(), groupedDurationFacetCollector); + TermGroupFacetCollector.GroupedFacetResult durationResult = groupedDurationFacetCollector.mergeSegmentResults(10, 0, false); + assertEquals(4, durationResult.getTotalCount()); + assertEquals(0, durationResult.getTotalMissingCount()); + + entries = durationResult.getFacetEntries(0, 10); + assertEquals(2, entries.size()); + assertEquals("10", entries.get(0).getValue().utf8ToString()); + assertEquals(2, entries.get(0).getCount()); + assertEquals("5", entries.get(1).getValue().utf8ToString()); + assertEquals(2, entries.get(1).getCount()); + + // 5 + doc = new Document(); + addGroupField(doc, groupField, "b", canUseIDV); + doc.add(new Field("duration", "5", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + // 6 + doc = new Document(); + addGroupField(doc, groupField, "b", canUseIDV); + doc.add(new Field("airport", "bru", TextField.TYPE_UNSTORED)); + doc.add(new Field("duration", "10", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + // 7 + doc = new Document(); + addGroupField(doc, groupField, "b", canUseIDV); + doc.add(new Field("airport", "bru", TextField.TYPE_UNSTORED)); + doc.add(new Field("duration", "15", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + // 8 + doc = new Document(); + addGroupField(doc, groupField, "a", canUseIDV); + doc.add(new Field("airport", "bru", TextField.TYPE_UNSTORED)); + doc.add(new Field("duration", "10", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + indexSearcher.getIndexReader().close(); + indexSearcher = new IndexSearcher(w.getReader()); + groupedAirportFacetCollector = new TermGroupFacetCollector(groupField, "airport", null, 128); + indexSearcher.search(new MatchAllDocsQuery(), groupedAirportFacetCollector); + airportResult = groupedAirportFacetCollector.mergeSegmentResults(3, 0, true); + assertEquals(5, airportResult.getTotalCount()); + assertEquals(1, airportResult.getTotalMissingCount()); + + entries = airportResult.getFacetEntries(1, 2); + assertEquals(2, entries.size()); + assertEquals("bru", entries.get(0).getValue().utf8ToString()); + assertEquals(2, entries.get(0).getCount()); + assertEquals("dus", entries.get(1).getValue().utf8ToString()); + assertEquals(1, entries.get(1).getCount()); + + groupedDurationFacetCollector = new TermGroupFacetCollector(groupField, "duration", null, 128); + indexSearcher.search(new MatchAllDocsQuery(), groupedDurationFacetCollector); + durationResult = groupedDurationFacetCollector.mergeSegmentResults(10, 2, true); + assertEquals(5, durationResult.getTotalCount()); + assertEquals(0, durationResult.getTotalMissingCount()); + + entries = durationResult.getFacetEntries(1, 1); + assertEquals(1, entries.size()); + assertEquals("5", entries.get(0).getValue().utf8ToString()); + assertEquals(2, entries.get(0).getCount()); + + // 9 + doc = new Document(); + addGroupField(doc, groupField, "c", canUseIDV); + doc.add(new Field("airport", "bru", TextField.TYPE_UNSTORED)); + doc.add(new Field("duration", "15", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + // 10 + doc = new Document(); + addGroupField(doc, groupField, "c", canUseIDV); + doc.add(new Field("airport", "dus", TextField.TYPE_UNSTORED)); + doc.add(new Field("duration", "10", TextField.TYPE_UNSTORED)); + w.addDocument(doc); + + indexSearcher.getIndexReader().close(); + indexSearcher = new IndexSearcher(w.getReader()); + groupedAirportFacetCollector = new TermGroupFacetCollector(groupField, "airport", null, 128); + indexSearcher.search(new MatchAllDocsQuery(), groupedAirportFacetCollector); + airportResult = groupedAirportFacetCollector.mergeSegmentResults(10, 0, false); + assertEquals(7, airportResult.getTotalCount()); + assertEquals(1, airportResult.getTotalMissingCount()); + + entries = airportResult.getFacetEntries(0, 10); + assertEquals(3, entries.size()); + assertEquals("ams", entries.get(0).getValue().utf8ToString()); + assertEquals(2, entries.get(0).getCount()); + assertEquals("bru", entries.get(1).getValue().utf8ToString()); + assertEquals(3, entries.get(1).getCount()); + assertEquals("dus", entries.get(2).getValue().utf8ToString()); + assertEquals(2, entries.get(2).getCount()); + + groupedDurationFacetCollector = new TermGroupFacetCollector(groupField, "duration", new BytesRef("1"), 128); + indexSearcher.search(new MatchAllDocsQuery(), groupedDurationFacetCollector); + durationResult = groupedDurationFacetCollector.mergeSegmentResults(10, 0, true); + assertEquals(5, durationResult.getTotalCount()); + assertEquals(0, durationResult.getTotalMissingCount()); + + entries = durationResult.getFacetEntries(0, 10); + assertEquals(2, entries.size()); + assertEquals("10", entries.get(0).getValue().utf8ToString()); + assertEquals(3, entries.get(0).getCount()); + assertEquals("15", entries.get(1).getValue().utf8ToString()); + assertEquals(2, entries.get(1).getCount()); +// assertEquals("5", entries.get(2).getValue().utf8ToString()); +// assertEquals(2, entries.get(2).getCount()); + + w.close(); + indexSearcher.getIndexReader().close(); + dir.close(); + } + + private void addGroupField(Document doc, String groupField, String value, boolean canUseIDV) { + doc.add(new Field(groupField, value, TextField.TYPE_UNSTORED)); + if (canUseIDV) { + doc.add(new DocValuesField(groupField, new BytesRef(value), DocValues.Type.BYTES_VAR_SORTED)); + } + } + +} Index: modules/grouping/src/java/org/apache/lucene/search/grouping/term/TermGroupFacetCollector.java IDEA additional info: Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP <+>UTF-8 =================================================================== --- modules/grouping/src/java/org/apache/lucene/search/grouping/term/TermGroupFacetCollector.java (revision ) +++ modules/grouping/src/java/org/apache/lucene/search/grouping/term/TermGroupFacetCollector.java (revision ) @@ -0,0 +1,351 @@ +package org.apache.lucene.search.grouping.term; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.index.AtomicReaderContext; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.FieldCache; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.PriorityQueue; +import org.apache.lucene.util.SentinelIntSet; +import org.apache.lucene.util.UnicodeUtil; + +import java.io.IOException; +import java.util.*; + +/** + * Term based grouped facets collector. + * + * @lucene.experimental + */ +public class TermGroupFacetCollector extends Collector { + + private final String groupField; + private final String facetField; + private final BytesRef facetPrefix; + + private final List groupedFacetHits; + private final SentinelIntSet segmentGroupedFacetHits; + private final List segmentResults; + private final BytesRef spare = new BytesRef(); + + private FieldCache.DocTermsIndex groupFieldTermsIndex; + private FieldCache.DocTermsIndex facetFieldTermsIndex; + private int[] segmentFacetCounts; + private int segmentTotalCount; + + private int startFacetOrd; + private int endFacetOrd; + + public TermGroupFacetCollector(String groupField, String facetField, BytesRef facetPrefix, int initialSize) { + this.groupField = groupField; + this.facetField = facetField; + this.facetPrefix = facetPrefix; + groupedFacetHits = new ArrayList(initialSize); + segmentGroupedFacetHits = new SentinelIntSet(initialSize, -1); + segmentResults = new ArrayList(); + } + + public GroupedFacetResult mergeSegmentResults(int size, int minCount, boolean orderByCount) throws IOException { + if (segmentTotalCount > 0) { + segmentResults.add(new SegmentResult(segmentFacetCounts, segmentTotalCount, facetFieldTermsIndex, startFacetOrd, endFacetOrd)); + segmentTotalCount = 0; // reset + } + + int totalCount = 0; + int missingCount = 0; + SegmentResultPriorityQueue segments = new SegmentResultPriorityQueue(segmentResults.size()); + for (SegmentResult segmentResult : segmentResults) { + totalCount += segmentResult.total; + missingCount += segmentResult.missing; + segmentResult.initializeForMerge(); + segments.add(segmentResult); + } + GroupedFacetResult facetResult = new GroupedFacetResult(size, minCount, orderByCount, missingCount, totalCount); + + while (segments.size() > 0) { + SegmentResult segmentResult = segments.top(); + BytesRef currentFacetValue = BytesRef.deepCopyOf(segmentResult.mergeTerm); + int count = 0; + + do { + count += segmentResult.counts[segmentResult.mergePos++]; + if (segmentResult.mergePos < segmentResult.maxTermPos) { + segmentResult.nextTerm(); + segmentResult = segments.updateTop(); + } else { + segments.pop(); + segmentResult = segments.top(); + if (segmentResult == null) { + break; + } + } + } while (currentFacetValue.equals(segmentResult.mergeTerm)); + +// System.out.println("Add entry " + currentFacetValue.utf8ToString() + " with count " + count); + facetResult.addFacetCount(currentFacetValue, count); + } + + return facetResult; + } + + public void collect(int doc) throws IOException { + int facetOrd = facetFieldTermsIndex.getOrd(doc); + if (facetOrd < startFacetOrd || facetOrd >= endFacetOrd) { + return; + } + + int groupOrd = groupFieldTermsIndex.getOrd(doc); + int segmentGroupedFacetsIndex = (groupOrd * facetFieldTermsIndex.numOrd()) + facetOrd; + if (segmentGroupedFacetHits.exists(segmentGroupedFacetsIndex)) { + return; + } + + segmentTotalCount++; + segmentFacetCounts[facetOrd]++; + + segmentGroupedFacetHits.put(segmentGroupedFacetsIndex); + groupedFacetHits.add( + new GroupedFacetHit( + groupFieldTermsIndex.lookup(groupOrd, new BytesRef()), + facetFieldTermsIndex.lookup(facetOrd, new BytesRef()) + ) + ); + } + + public void setNextReader(AtomicReaderContext context) throws IOException { + if (segmentTotalCount > 0) { + segmentResults.add(new SegmentResult(segmentFacetCounts, segmentTotalCount, facetFieldTermsIndex, startFacetOrd, endFacetOrd)); + } + + groupFieldTermsIndex = FieldCache.DEFAULT.getTermsIndex(context.reader(), groupField); + facetFieldTermsIndex = FieldCache.DEFAULT.getTermsIndex(context.reader(), facetField); + segmentFacetCounts = new int[facetFieldTermsIndex.numOrd()]; + segmentTotalCount = 0; + + segmentGroupedFacetHits.clear(); + for (GroupedFacetHit groupedFacetHit : groupedFacetHits) { + int facetOrdinal = facetFieldTermsIndex.binarySearchLookup(groupedFacetHit.facetValue, spare); + if (facetOrdinal < 0) { + continue; + } + + int groupOrdinal = groupFieldTermsIndex.binarySearchLookup(groupedFacetHit.groupValue, spare); + if (groupOrdinal < 0) { + continue; + } + + int checkerIndex = (groupOrdinal * facetFieldTermsIndex.numOrd()) + facetOrdinal; + segmentGroupedFacetHits.put(checkerIndex); + } + + if (facetPrefix != null) { + startFacetOrd = facetFieldTermsIndex.binarySearchLookup(facetPrefix, spare); + if (startFacetOrd < 0) { + startFacetOrd = -startFacetOrd - 1; // + } + BytesRef facetEndPrefix = BytesRef.deepCopyOf(facetPrefix); + facetEndPrefix.append(UnicodeUtil.BIG_TERM); + endFacetOrd = facetFieldTermsIndex.binarySearchLookup(facetEndPrefix, spare); + endFacetOrd = -endFacetOrd - 1; // Recalculates the ord after facetEndPrefix. + } else { + startFacetOrd = 0; + endFacetOrd = facetFieldTermsIndex.numOrd(); + } + } + + public void setScorer(Scorer scorer) throws IOException { + } + + public boolean acceptsDocsOutOfOrder() { + return true; + } + + public static class GroupedFacetResult { + + private final static Comparator orderByCountAndValue = new Comparator() { + + public int compare(FacetEntry a, FacetEntry b) { + int cmp = b.count - a.count; // Highest count first! + if (cmp != 0) { + return cmp; + } + return a.value.compareTo(b.value); + } + + }; + + private final static Comparator orderByValue = new Comparator() { + + public int compare(FacetEntry a, FacetEntry b) { + return a.value.compareTo(b.value); + } + + }; + + private final int maxSize; + private final NavigableSet facetEntries; + private final int totalCount; + private final int totalMissingCount; + + private int currentMin; + + public GroupedFacetResult(int size, int minCount, boolean orderByCount, int totalMissingCount, int totalCount) { + this.totalMissingCount = totalMissingCount; + this.totalCount = totalCount; + this.facetEntries = new TreeSet(orderByCount ? orderByCountAndValue : orderByValue); + + maxSize = size; + currentMin = minCount; + } + + public void addFacetCount(BytesRef facetValue, int count) { + if (count < currentMin) { + return; + } + + if (facetEntries.size() == maxSize) { + facetEntries.pollLast(); + } + + facetEntries.add(new FacetEntry(facetValue, count)); + if (facetEntries.size() == maxSize) { + currentMin = facetEntries.last().count; + } + } + + public List getFacetEntries(int offset, int limit) { + List entries = new LinkedList(); + limit += offset; + + int i = 0; + for (FacetEntry facetEntry : facetEntries) { + if (i < offset) { + i++; + continue; + } + if (i++ >= limit) { + break; + } + entries.add(facetEntry); + } + return entries; + } + + public int getTotalCount() { + return totalCount; + } + + public int getTotalMissingCount() { + return totalMissingCount; + } + } + + public static class FacetEntry { + + private final BytesRef value; + private final int count; + + FacetEntry(BytesRef value, int count) { + this.value = value; + this.count = count; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + FacetEntry that = (FacetEntry) o; + return value.equals(that.value); + } + + @Override + public int hashCode() { + return value.hashCode(); + } + + public BytesRef getValue() { + return value; + } + + public int getCount() { + return count; + } + } + +} + +class SegmentResult { + + final int[] counts; + final int total; + final int missing; + final FieldCache.DocTermsIndex facetFieldTermsIndex; + + // Used for merging the segment results + BytesRef mergeTerm; + int mergePos; + final int maxTermPos; + TermsEnum tenum; + + SegmentResult(int[] counts, int total, FieldCache.DocTermsIndex facetFieldTermsIndex, int startFacetOrd, int endFacetOrd) { + this.counts = counts; + this.missing = counts[0]; + this.total = total - missing; + this.facetFieldTermsIndex = facetFieldTermsIndex; + this.mergePos = startFacetOrd; + this.maxTermPos = endFacetOrd; + } + + void initializeForMerge() throws IOException { + tenum = facetFieldTermsIndex.getTermsEnum(); + mergePos = 1; + tenum.seekExact(mergePos); + mergeTerm = tenum.term(); + } + + void nextTerm() throws IOException { + mergeTerm = tenum.next(); + } + +} + +class GroupedFacetHit { + + final BytesRef groupValue; + final BytesRef facetValue; + + GroupedFacetHit(BytesRef groupValue, BytesRef facetValue) { + this.groupValue = groupValue; + this.facetValue = facetValue; + } +} + +class SegmentResultPriorityQueue extends PriorityQueue { + + SegmentResultPriorityQueue(int maxSize) { + super(maxSize); + } + + protected boolean lessThan(SegmentResult a, SegmentResult b) { + return a.mergeTerm.compareTo(b.mergeTerm) < 0; + } +}