diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanNovelScorer.java b/lucene/core/src/java/org/apache/lucene/search/BooleanNovelScorer.java new file mode 100644 index 0000000..4989951 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/BooleanNovelScorer.java @@ -0,0 +1,384 @@ +package org.apache.lucene.search; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.IOException; +import java.util.List; + +import org.apache.lucene.search.BooleanQuery.BooleanWeight; + +/** + * This is an improvement of {@link BooleanScorer}. + * It only supports cases where there is at least one MUST clause. + */ +final class BooleanNovelScorer extends Scorer { + + private static final class Bucket { + int doc; // doc id + // score is divided into RS and OS, so that its calculating order + // can be the same as the DAAT procedure. + double requiredScore; // incremental required score + double optionalScore; // incremental optional score + int coord; // count of terms in score + boolean valid; + + void set(Bucket bucket) { + this.doc = bucket.doc; + this.requiredScore = bucket.requiredScore; + this.optionalScore = bucket.optionalScore; + this.coord = bucket.coord; + this.valid = bucket.valid; + } + } + + /** A simple hash table of document scores within a range. */ + private final class BucketTable { + static final int SIZE = 1 << 11; + + private final Bucket[] buckets = new Bucket[SIZE]; + // After collecting more documents, if there are more documents not collected. + boolean more = true; + private int numOfBuckets = 0; + + BucketTable() { + // Pre-fill to save the lazy init when collecting + // each sub: + for(int idx=0;idxmore is true, while numOfBuckets is 0. + * @return true if more matching documents may remain. + */ + boolean collectMore() throws IOException { + numOfBuckets = 0; + + // Scan requiredDocs full fill the bucket table + while (numOfBuckets < SIZE) { + final int requiredDocID = requiredConjunctionScorer.nextDoc(); + if (requiredDocID == DocIdSetIterator.NO_MORE_DOCS) { + more = false; + break; + } + final Bucket bucket = buckets[numOfBuckets]; + bucket.doc = requiredDocID; + bucket.coord = requiredScorers.size(); + bucket.requiredScore = requiredConjunctionScorer.score(); + bucket.optionalScore = 0; + numOfBuckets ++; + } + + // Scan prohibitedDocs to remove docs from bucket table. + for (Scorer prohibitedScorer : prohibitedScorers) { + int i = 0; + int newNumOfBuckets = 0; + while (i < numOfBuckets) { + final Bucket bucket = buckets[i]; + int prohibitedDocID = prohibitedScorer.docID(); + if (prohibitedDocID < bucket.doc) { + // According to the definition of .advance(), + // prohibitedDocID should be less then bucket.doc, + // before calling .advance(bucket.doc); + + // Skip to bucket.doc, so that prohibitedDocID >= bucket.doc + prohibitedDocID = prohibitedScorer.advance(bucket.doc); + } + + if (prohibitedDocID == DocIdSetIterator.NO_MORE_DOCS) { + if (i > newNumOfBuckets) { + while (i < numOfBuckets) { + buckets[newNumOfBuckets ++].set(buckets[i ++]); + } + } else { + newNumOfBuckets = numOfBuckets; + } + break; + + } else if (prohibitedDocID == bucket.doc) { + // remove the prohibited bucket. + i ++; + + } else { // prohibitedDocID > bucket.doc + if (i > newNumOfBuckets) { + while (i < numOfBuckets && prohibitedDocID > buckets[i].doc) { + buckets[newNumOfBuckets ++].set(buckets[i ++]); + } + } else { + i = skipsTo(i, prohibitedDocID); + newNumOfBuckets = i; + } + } + } + numOfBuckets = newNumOfBuckets; + } + + // Scan optionalDocs to add coord and score. + // TODO: use countLeft to judge whether a bucket should be removed. +// int countLeft = optionalScorers.size(); + for (Scorer optionalScorer : optionalScorers) { +// countLeft --; + int i = 0; + while (i < numOfBuckets) { + final Bucket bucket = buckets[i]; + int optionalDocID = optionalScorer.docID(); + if (optionalDocID < bucket.doc) { + // According to the definition of .advance(), + // optionalDocID should be less then bucket.doc, + // before calling .advance(bucket.doc); + + // Skip to bucket.doc, so that optionalDocID >= bucket.doc + optionalDocID = optionalScorer.advance(bucket.doc); + } + + if (optionalDocID == DocIdSetIterator.NO_MORE_DOCS) { + break; + + } else if (optionalDocID == bucket.doc) { + bucket.coord ++; + bucket.optionalScore += optionalScorer.score(); + i ++; +// if (oldBucket.coord + countLeft < minNrShouldMatch) remove(oldBucket); + + } else { // optionalDocID > bucket.doc + // current bucket skips to optionalDocID. + i = skipsTo(i, optionalDocID); +// while (i < numOfBuckets && optionalDocID > buckets[i].doc) { +// i ++; +// } + } + } + } + + if (more && numOfBuckets == 0) { + // If there are more docs not collected, but no doc is collected in this iteration, + // collect more again. + return collectMore(); + } + return more; + } + + + /** + * Skips to target from begin.
+ * NOTE: Undefined when buckets[i].doc >= target. + * @param begin the index begin to skip, MUST s.t. -1 ≤ begin. + * @param target the target doc to skip. + * @return the first index i s.t. buckets[i].doc >= target. + */ + int skipsTo(int begin, int target) { + final int DELTA = 16; + int i = begin + DELTA; + while (i < numOfBuckets && buckets[i].doc < target) { + i += DELTA; + } + + int end = i; + if (end >= numOfBuckets) { + end = numOfBuckets; + + } else if (buckets[end].doc == target) { + return end; + } + + i = i - DELTA + 1; + while (i < end && buckets[i].doc < target) { + i ++; + } + return i; + } + + void advance(int target) throws IOException { + // advance requiredConjunctionScorer to target doc + if (requiredConjunctionScorer.docID() < target) { + requiredConjunctionScorer.advance(target); + } + + // Scan prohibitedDocs to advance to target doc + for (Scorer prohibitedScorer : prohibitedScorers) { + if (prohibitedScorer.docID() < target) { + prohibitedScorer.advance(target); + } + } + + // Scan optionalDocs to advance to target doc + for (Scorer optionalScorer : optionalScorers) { + if (optionalScorer.docID() < target) { + optionalScorer.advance(target); + } + } + } + } + + private final BucketTable bucketTable = new BucketTable(); + private final float[] coordFactors; + private int currentIndex = -1; + private int currentDoc = -1; + + final private Scorer requiredConjunctionScorer; + final private List requiredScorers; + final private List optionalScorers; + final private List prohibitedScorers; + // minNrShouldMatch only applies to SHOULD clauses + final private int minNrShouldMatch; + + BooleanNovelScorer(BooleanWeight weight, boolean disableCoord, int minNrShouldMatch, + List requiredScorers, List optionalScorers, List prohibitedScorers, + int maxCoord) throws IOException { + super(weight); + if (requiredScorers.size() == 0) { + throw new IllegalArgumentException("requriedScorers.size() must be > 0"); + } + if (minNrShouldMatch < 0) { + throw new IllegalArgumentException("Minimum number of optional scorers should not be negative"); + } + this.minNrShouldMatch = minNrShouldMatch; + + this.requiredScorers = requiredScorers; + this.optionalScorers = optionalScorers; + this.prohibitedScorers = prohibitedScorers; + this.requiredConjunctionScorer = new ConjunctionScorer( + this.weight, this.requiredScorers.toArray(new Scorer[this.requiredScorers.size()])); + + coordFactors = new float[requiredScorers.size() + optionalScorers.size() + 1]; + for (int i = 0; i < coordFactors.length; i++) { + coordFactors[i] = disableCoord ? 1.0f : weight.coord(i, maxCoord); + } + } + + @Override + public String toString() { + StringBuffer buffer = new StringBuffer(); + buffer.append("boolean("); + for (Scorer requiredScorer : requiredScorers) { + buffer.append("+"); + buffer.append(requiredScorer.toString()); + buffer.append(" "); + } + for (Scorer optionalScorer : optionalScorers) { + buffer.append(optionalScorer.toString()); + buffer.append(" "); + } + for (Scorer prohibitedScorer : prohibitedScorers) { + buffer.append("-"); + buffer.append(prohibitedScorer.toString()); + buffer.append(" "); + } + // Delete the last whitespace + if (buffer.length() > "boolean(".length()) { + buffer.deleteCharAt(buffer.length() - 1); + } + buffer.append(")"); + return buffer.toString(); + } + + @Override + public float score() throws IOException { + final Bucket bucket = bucketTable.buckets[currentIndex]; + // Cast required score and optional score to float, + // in order to make the calculating procedure is the same as DAAT. + return (float) ((float) bucket.requiredScore + (float) bucket.optionalScore) * coordFactors[bucket.coord]; + } + + @Override + public int freq() throws IOException { + return bucketTable.buckets[currentIndex].coord; + } + + @Override + public int docID() { + return currentDoc; + } + + @Override + public int nextDoc() throws IOException { + currentIndex ++; + if (bucketTable.more && currentIndex >= bucketTable.numOfBuckets) { + bucketTable.collectMore(); + currentIndex = 0; + } + + while (currentIndex < bucketTable.numOfBuckets) { + final Bucket bucket = bucketTable.buckets[currentIndex]; + + if (bucket.coord - requiredScorers.size() >= minNrShouldMatch) { + return currentDoc = bucket.doc; + } + + currentIndex ++; + if (bucketTable.more && currentIndex >= bucketTable.numOfBuckets) { + bucketTable.collectMore(); + currentIndex = 0; + } + } + return currentDoc = DocIdSetIterator.NO_MORE_DOCS; + } + + @Override + public int advance(int target) throws IOException { + return slowAdvance(target); +// if (target == currentDoc) { +// return currentDoc; +// } +// +// if (bucketTable.numOfBuckets == 0 && bucketTable.more) { +// // The beginning +// bucketTable.advance(target); +// bucketTable.collectMore(); +// currentIndex = -1; +// return nextDoc(); +// } +// +// if (bucketTable.numOfBuckets == 0) { +// return currentDoc = DocIdSetIterator.NO_MORE_DOCS; +// } +// +// final Bucket last = bucketTable.buckets[bucketTable.numOfBuckets - 1]; +// if (target > last.doc) { +// bucketTable.advance(target); +// bucketTable.collectMore(); +// currentIndex = -1; +// return nextDoc(); +// } +// +// // target <= last.doc +// currentIndex = bucketTable.skipsTo(currentIndex, target); +// if (currentIndex < bucketTable.numOfBuckets) { +// final Bucket bucket = bucketTable.buckets[currentIndex]; +// if (bucket.coord - requiredScorers.size() >= minNrShouldMatch) { +// return currentDoc = bucket.doc; +// } +// return nextDoc(); +// } +// +// // currentIndex >= bucketTable.numOfBuckets +// if (bucketTable.more) { +// bucketTable.collectMore(); +// currentIndex = -1; +// return nextDoc(); +// } +// return currentDoc = DocIdSetIterator.NO_MORE_DOCS; + } + + @Override + public long cost() { + return requiredConjunctionScorer.cost(); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java b/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java index 4d7635d..dcb8a7f 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java @@ -314,6 +314,7 @@ public class BooleanQuery extends Query implements Iterable { return super.bulkScorer(context, scoreDocsInOrder, acceptDocs); } + List required = new ArrayList(); List prohibited = new ArrayList(); List optional = new ArrayList(); Iterator cIter = clauses.iterator(); @@ -329,6 +330,15 @@ public class BooleanQuery extends Query implements Iterable { // would handle conjunctions faster than // BooleanScorer2... return super.bulkScorer(context, scoreDocsInOrder, acceptDocs); + +// Scorer requiredSubScorer = w.scorer(context, acceptDocs); +// // if no doc matches required, then return null to say +// // no doc matches this Boolean Query. +// if (requiredSubScorer == null) { +// return null; +// } +// required.add(requiredSubScorer); + } else if (c.isProhibited()) { prohibited.add(subScorer); } else { @@ -336,7 +346,7 @@ public class BooleanQuery extends Query implements Iterable { } } - return new BooleanScorer(this, disableCoord, minNrShouldMatch, optional, prohibited, maxCoord); + return new BooleanScorer(this, disableCoord, minNrShouldMatch, required, optional, prohibited, maxCoord); } @Override @@ -385,6 +395,12 @@ public class BooleanQuery extends Query implements Iterable { // no documents will be matched by the query return null; } + + // TODO: Find a better trigger condition on calling BooleanNovalScorer + if (required.size() > 0 && (optional.size() > 1 || prohibited.size() > 1)) { + return new BooleanNovelScorer(this, disableCoord, minNrShouldMatch, required, optional, prohibited, maxCoord); +// return new BooleanScorerInOrder(this, disableCoord, minNrShouldMatch, required, optional, prohibited, maxCoord); + } // three cases: conjunction, disjunction, or mix @@ -406,7 +422,7 @@ public class BooleanQuery extends Query implements Iterable { Scorer req = excl(req(required, true), prohibited); Scorer opt = opt(optional, minShouldMatch, true); - + // TODO: clean this up: its horrible if (disableCoord) { if (minShouldMatch > 0) { diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanScorer.java b/lucene/core/src/java/org/apache/lucene/search/BooleanScorer.java index 173bb44..c08efd5 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BooleanScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/BooleanScorer.java @@ -61,7 +61,7 @@ import org.apache.lucene.search.BooleanQuery.BooleanWeight; final class BooleanScorer extends BulkScorer { - private static final class BooleanScorerCollector extends SimpleCollector { + private final class BooleanScorerCollector extends SimpleCollector { private BucketTable bucketTable; private int mask; private Scorer scorer; @@ -77,18 +77,24 @@ final class BooleanScorer extends BulkScorer { final int i = doc & BucketTable.MASK; final Bucket bucket = table.buckets[i]; + final int coord = (mask & REQUIRED_MASK) == REQUIRED_MASK ? requiredNrMatch : 1; if (bucket.doc != doc) { // invalid bucket bucket.doc = doc; // set doc - bucket.score = scorer.score(); // initialize score + bucket.requiredScore = 0; // initialize required score + bucket.optionalScore = 0; // initialize optional score bucket.bits = mask; // initialize mask - bucket.coord = 1; // initialize coord - + bucket.coord = coord; // initialize coord bucket.next = table.first; // push onto valid list table.first = bucket; } else { // valid bucket - bucket.score += scorer.score(); // increment score bucket.bits |= mask; // add bits in mask - bucket.coord++; // increment coord + bucket.coord += coord; // increment coord + } + + if ((mask & REQUIRED_MASK) == REQUIRED_MASK) { // Required doc + bucket.requiredScore += scorer.score(); + } else if (mask == 0) { // Optional doc + bucket.optionalScore += scorer.score(); } } @@ -106,17 +112,18 @@ final class BooleanScorer extends BulkScorer { static final class Bucket { int doc = -1; // tells if bucket is valid - double score; // incremental score - // TODO: break out bool anyProhibited, int - // numRequiredMatched; then we can remove 32 limit on - // required clauses + // score is divided into RS and OS, so that its calculating order + // can be the same as the schedule of DAAT. + double requiredScore; // incremental required score + double optionalScore; // incremental optional score + int bits; // used for bool constraints int coord; // count of terms in score Bucket next; // next valid bucket } /** A simple hash table of document scores within a range. */ - static final class BucketTable { + final class BucketTable { public static final int SIZE = 1 << 11; public static final int MASK = SIZE - 1; @@ -140,8 +147,7 @@ final class BooleanScorer extends BulkScorer { static final class SubScorer { public BulkScorer scorer; - // TODO: re-enable this if BQ ever sends us required clauses - //public boolean required = false; + public boolean required = false; public boolean prohibited; public LeafCollector collector; public SubScorer next; @@ -149,13 +155,9 @@ final class BooleanScorer extends BulkScorer { public SubScorer(BulkScorer scorer, boolean required, boolean prohibited, LeafCollector collector, SubScorer next) { - if (required) { - throw new IllegalArgumentException("this scorer cannot handle required=true"); - } this.scorer = scorer; this.more = true; - // TODO: re-enable this if BQ ever sends us required clauses - //this.required = required; + this.required = required; this.prohibited = prohibited; this.collector = collector; this.next = next; @@ -165,20 +167,32 @@ final class BooleanScorer extends BulkScorer { private SubScorer scorers = null; private BucketTable bucketTable = new BucketTable(); private final float[] coordFactors; - // TODO: re-enable this if BQ ever sends us required clauses - //private int requiredMask = 0; + // minNrShouldMatch only applies to SHOULD clauses private final int minNrShouldMatch; private int end; private Bucket current; // Any time a prohibited clause matches we set bit 0: private static final int PROHIBITED_MASK = 1; + // Any time a required clause matches we set bit 1: + private static final int REQUIRED_MASK = 2; + // requiredNrMatch applies to MUST clauses + private final int requiredNrMatch; private final Weight weight; - BooleanScorer(BooleanWeight weight, boolean disableCoord, int minNrShouldMatch, - List optionalScorers, List prohibitedScorers, int maxCoord) throws IOException { + BooleanScorer(BooleanWeight weight, boolean disableCoord, int minNrShouldMatch, + List requiredScorers, List optionalScorers, List prohibitedScorers, + int maxCoord) throws IOException { + this.minNrShouldMatch = minNrShouldMatch; this.weight = weight; + + this.requiredNrMatch = requiredScorers.size(); + if (this.requiredNrMatch > 0) { + BulkScorer requiredScorer = new Weight.DefaultBulkScorer(new ConjunctionScorer( + this.weight, requiredScorers.toArray(new Scorer[requiredScorers.size()]))); + scorers = new SubScorer(requiredScorer, true, false, bucketTable.newCollector(REQUIRED_MASK), scorers); + } for (BulkScorer scorer : optionalScorers) { scorers = new SubScorer(scorer, false, false, bucketTable.newCollector(0), scorers); @@ -188,7 +202,7 @@ final class BooleanScorer extends BulkScorer { scorers = new SubScorer(scorer, false, true, bucketTable.newCollector(PROHIBITED_MASK), scorers); } - coordFactors = new float[optionalScorers.size() + 1]; + coordFactors = new float[requiredScorers.size() + optionalScorers.size() + 1]; for (int i = 0; i < coordFactors.length; i++) { coordFactors[i] = disableCoord ? 1.0f : weight.coord(i, maxCoord); } @@ -209,12 +223,9 @@ final class BooleanScorer extends BulkScorer { while (current != null) { // more queued // check prohibited & required - if ((current.bits & PROHIBITED_MASK) == 0) { + if ((current.bits & PROHIBITED_MASK) == 0 && + (requiredNrMatch == 0 || (current.bits & REQUIRED_MASK) == REQUIRED_MASK)) { - // TODO: re-enable this if BQ ever sends us required - // clauses - //&& (current.bits & requiredMask) == requiredMask) { - // NOTE: Lucene always passes max = // Integer.MAX_VALUE today, because we never embed // a BooleanScorer inside another (even though @@ -229,8 +240,10 @@ final class BooleanScorer extends BulkScorer { continue; } - if (current.coord >= minNrShouldMatch) { - fs.score = (float) (current.score * coordFactors[current.coord]); + if (current.coord - requiredNrMatch >= minNrShouldMatch) { + // Cast required score and optional score to float, + // in order to make the calculating procedure is the same as DAAT. + fs.score = (float) (((float) current.requiredScore + (float) current.optionalScore) * coordFactors[current.coord]); fs.doc = current.doc; fs.freq = current.coord; collector.collect(current.doc); diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanScorerInOrder.java b/lucene/core/src/java/org/apache/lucene/search/BooleanScorerInOrder.java new file mode 100644 index 0000000..2d31ccd --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/BooleanScorerInOrder.java @@ -0,0 +1,324 @@ +package org.apache.lucene.search; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.IOException; +import java.util.List; + +import org.apache.lucene.search.BooleanQuery.BooleanWeight; +import org.apache.lucene.util.FixedBitSet; + +/** + * This is an improvement of {@link BooleanScorer}. + * It supports all cases even when there're MUST clauses. + */ +final class BooleanScorerInOrder extends Scorer { + + + private static final class Bucket { + int doc; // doc id + // score is divided into RS and OS, so that its calculating order + // can be the same as the schedule of DAAT. + double requiredScore; // incremental required score + double optionalScore; // incremental optional score + int coord; // count of terms in score + } + + /** A simple hash table of document scores within a range. */ + private final class BucketTable { + static final int SIZE = 1 << 11; + static final int MASK = SIZE - 1; + + private FixedBitSet bitSet = new FixedBitSet(SIZE); + private boolean bitSetIsClean = true; + private final Bucket[] buckets = new Bucket[SIZE]; + // After collecting more documents, if there are more documents not collected. + boolean more = true; + int end = 0; + + BucketTable() { + // Pre-fill to save the lazy init when collecting + // each sub: + for(int idx=0;idx= 0) { + final Bucket bucket = buckets[bucketIndex]; + int prohibitedDocID = prohibitedScorer.docID(); + if (prohibitedDocID < bucket.doc) { + // According to the definition of .advance(), + // prohibitedDocID should be less then bucket.doc, + // before calling .advance(bucket.doc); + + // Skip to bucket.doc, so that prohibitedDocID >= bucket.doc + prohibitedDocID = prohibitedScorer.advance(bucket.doc); + } + + if (prohibitedDocID == DocIdSetIterator.NO_MORE_DOCS) { + break; + + } else if (prohibitedDocID == bucket.doc) { + // remove the prohibited bucket. + bitSet.clear(bucketIndex); + bucketIndex = bucketIndex + 1 < SIZE ? bitSet.nextSetBit(bucketIndex + 1) : -1; + + } else { // prohibitedDocID > bucket.doc + if (prohibitedDocID >= end) { + bucketIndex = -1; + } else { + bucketIndex = bitSet.nextSetBit(prohibitedDocID & MASK); + } + } + } + } + + // Scan optionalDocs to add coord and score. + // TODO: use countLeft to judge whether a bucket should be removed. +// int countLeft = optionalScorers.size(); + for (Scorer optionalScorer : optionalScorers) { +// countLeft --; + int bucketIndex = bitSet.nextSetBit(0); + while (bucketIndex >= 0) { + final Bucket bucket = buckets[bucketIndex]; + int optionalDocID = optionalScorer.docID(); + if (optionalDocID < bucket.doc) { + // According to the definition of .advance(), + // optionalDocID should be less then bucket.doc, + // before calling .advance(bucket.doc); + + // Skip to bucket.doc, so that optionalDocID >= bucket.doc + optionalDocID = optionalScorer.advance(bucket.doc); + } + + if (optionalDocID == DocIdSetIterator.NO_MORE_DOCS) { + break; + + } else if (optionalDocID == bucket.doc) { + bucket.coord ++; + bucket.optionalScore += optionalScorer.score(); + bucketIndex = bucketIndex + 1 < SIZE ? bitSet.nextSetBit(bucketIndex + 1) : -1; +// if (bucket.coord + countLeft < minNrShouldMatch) remove(bucket); + + } else { // optionalDocID > bucket.doc + // current bucket advances to prohibtedDocID. + if (optionalDocID >= end) { + bucketIndex = -1; + } else { + bucketIndex = bitSet.nextSetBit(optionalDocID & MASK); + } + } + } + } + + if (more && bitSet.nextSetBit(0) < 0) { + // If there are more docs not collected, but no doc is collected in this iteration, + // collect more again. + return collectMore(); + } + return more; + } + + + void advance(int target) throws IOException { + end = target & ~MASK; + // advance requiredConjunctionScorer to target doc + if (requiredConjunctionScorer.docID() < target) { + requiredConjunctionScorer.advance(target); + } + + // Scan prohibitedDocs to advance to target doc. + for (Scorer prohibitedScorer : prohibitedScorers) { + if (prohibitedScorer.docID() < target) { + prohibitedScorer.advance(target); + } + } + + // Scan optionalDocs to advance to target doc. + for (Scorer optionalScorer : optionalScorers) { + if (optionalScorer.docID() < target) { + optionalScorer.advance(target); + } + } + currentIndex = -1; + currentDoc = -1; + } + } + + private final BucketTable bucketTable = new BucketTable(); + private final float[] coordFactors; + private int currentIndex = -1; + private int currentDoc = -1; + + final private Scorer requiredConjunctionScorer; + final private List requiredScorers; + final private List optionalScorers; + final private List prohibitedScorers; + // minNrShouldMatch only applies to SHOULD clauses + final private int minNrShouldMatch; + + BooleanScorerInOrder(BooleanWeight weight, boolean disableCoord, int minNrShouldMatch, + List requiredScorers, List optionalScorers, List prohibitedScorers, + int maxCoord) throws IOException { + super(weight); + if (requiredScorers.size() == 0) { + throw new IllegalArgumentException("requriedScorers.size() must be > 0"); + } + if (minNrShouldMatch < 0) { + throw new IllegalArgumentException("Minimum number of optional scorers should not be negative"); + } + this.minNrShouldMatch = minNrShouldMatch; + + this.requiredScorers = requiredScorers; + this.optionalScorers = optionalScorers; + this.prohibitedScorers = prohibitedScorers; + this.requiredConjunctionScorer = new ConjunctionScorer( + this.weight, this.requiredScorers.toArray(new Scorer[this.requiredScorers.size()])); + + coordFactors = new float[requiredScorers.size() + optionalScorers.size() + 1]; + for (int i = 0; i < coordFactors.length; i++) { + coordFactors[i] = disableCoord ? 1.0f : weight.coord(i, maxCoord); + } + } + + @Override + public String toString() { + StringBuffer buffer = new StringBuffer(); + buffer.append("boolean("); + for (Scorer requiredScorer : requiredScorers) { + buffer.append("+"); + buffer.append(requiredScorer.toString()); + buffer.append(" "); + } + for (Scorer optionalScorer : optionalScorers) { + buffer.append(optionalScorer.toString()); + buffer.append(" "); + } + for (Scorer prohibitedScorer : prohibitedScorers) { + buffer.append("-"); + buffer.append(prohibitedScorer.toString()); + buffer.append(" "); + } + // Delete the last whitespace + if (buffer.length() > "boolean(".length()) { + buffer.deleteCharAt(buffer.length() - 1); + } + buffer.append(")"); + return buffer.toString(); + } + + @Override + public float score() throws IOException { + if (currentIndex < 0) return Float.NaN; + final Bucket bucket = bucketTable.buckets[currentIndex]; + // Cast required score and optional score to float, + // in order to make the calculating procedure is the same as DAAT. + return (float) (((float) bucket.requiredScore + (float) bucket.optionalScore) * coordFactors[bucket.coord]); + } + + @Override + public int freq() throws IOException { + return currentIndex >= 0 ? bucketTable.buckets[currentIndex].coord : 0; + } + + @Override + public int docID() { + return currentDoc; + } + + @Override + public int nextDoc() throws IOException { + if (currentIndex >= 0) { + currentIndex = currentIndex + 1 < BucketTable.SIZE ? bucketTable.bitSet.nextSetBit(currentIndex + 1) : -1; + } + if (bucketTable.more && currentIndex < 0) { + bucketTable.collectMore(); + currentIndex = bucketTable.bitSet.nextSetBit(0); + } + + while (currentIndex >= 0) { + final Bucket bucket = bucketTable.buckets[currentIndex]; + if (bucket.coord - requiredScorers.size() >= minNrShouldMatch) { + return currentDoc = bucket.doc; + + } else { + currentIndex = currentIndex + 1 < BucketTable.SIZE ? bucketTable.bitSet.nextSetBit(currentIndex + 1) : -1; + } + + if (bucketTable.more && currentIndex < 0) { + bucketTable.collectMore(); + currentIndex = bucketTable.bitSet.nextSetBit(0); + } + } + return currentDoc = DocIdSetIterator.NO_MORE_DOCS; + } + + @Override + public int advance(int target) throws IOException { + // TODO: should use trickier method to skip documents. + if (target < bucketTable.end) { + return slowAdvance(target); + + } else { + bucketTable.advance(target); + return slowAdvance(target); + } + } + + @Override + public long cost() { + return requiredConjunctionScorer.cost(); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/ConjunctionScorer.java b/lucene/core/src/java/org/apache/lucene/search/ConjunctionScorer.java index 3e81187..e303ed8 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ConjunctionScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/ConjunctionScorer.java @@ -102,12 +102,11 @@ class ConjunctionScorer extends Scorer { @Override public float score() throws IOException { - // TODO: sum into a double and cast to float if we ever send required clauses to BS1 - float sum = 0.0f; + double sum = 0.0f; for (DocsAndFreqs docs : docsAndFreqs) { sum += docs.scorer.score(); } - return sum * coord; + return (float) (sum * coord); } @Override diff --git a/lucene/core/src/test/org/apache/lucene/search/TestBooleanScorer.java b/lucene/core/src/test/org/apache/lucene/search/TestBooleanScorer.java index 358a513..69d71af 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestBooleanScorer.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestBooleanScorer.java @@ -96,7 +96,7 @@ public class TestBooleanScorer extends LuceneTestCase { } }}; - BooleanScorer bs = new BooleanScorer(weight, false, 1, Arrays.asList(scorers), Collections.emptyList(), scorers.length); + BooleanScorer bs = new BooleanScorer(weight, false, 1, Collections.emptyList(), Arrays.asList(scorers), Collections.emptyList(), scorers.length); final List hits = new ArrayList<>(); bs.score(new SimpleCollector() { diff --git a/lucene/core/src/test/org/apache/lucene/search/TestBooleanUnevenly.java b/lucene/core/src/test/org/apache/lucene/search/TestBooleanUnevenly.java new file mode 100644 index 0000000..2e393de --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestBooleanUnevenly.java @@ -0,0 +1,132 @@ +package org.apache.lucene.search; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.index.Term; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.LuceneTestCase; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +/** + * {@link BooleanNovelScorer} should be tested, when hit documents + * are very unevenly distributed. + */ +public class TestBooleanUnevenly extends LuceneTestCase { + private static IndexSearcher searcher; + private static IndexReader reader; + + public static final String field = "field"; + private static Directory directory; + + private static int count1 = 0; + + @BeforeClass + public static void beforeClass() throws Exception { + directory = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter(random(), directory, new MockAnalyzer(random())); + Document doc; + for (int i=0;i<2;i++) { + for (int j=0;j<2048;j++) { + doc = new Document(); + doc.add(newTextField(field, "1", Field.Store.NO)); + count1 ++; + w.addDocument(doc); + } + for (int j=0;j<2048;j++) { + doc = new Document(); + doc.add(newTextField(field, "2", Field.Store.NO)); + w.addDocument(doc); + } + doc = new Document(); + doc.add(newTextField(field, "1", Field.Store.NO)); + count1 ++; + w.addDocument(doc); + for (int j=0;j<2048;j++) { + doc = new Document(); + doc.add(newTextField(field, "2", Field.Store.NO)); + w.addDocument(doc); + } + } + reader = w.getReader(); + searcher = newSearcher(reader); + w.shutdown(); + } + + @AfterClass + public static void afterClass() throws Exception { + reader.close(); + directory.close(); + searcher = null; + reader = null; + directory = null; + } + + @Test + public void testQueries01() throws Exception { + BooleanQuery query = new BooleanQuery(); + query.add(new TermQuery(new Term(field, "1")), BooleanClause.Occur.MUST); + query.add(new TermQuery(new Term(field, "1")), BooleanClause.Occur.SHOULD); + query.add(new TermQuery(new Term(field, "2")), BooleanClause.Occur.SHOULD); + + TopScoreDocCollector collector = TopScoreDocCollector.create(1000, false); + searcher.search(query, null, collector); + TopDocs tops1 = collector.topDocs(); + ScoreDoc[] hits1 = tops1.scoreDocs; + int hitsNum1 = tops1.totalHits; + + collector = TopScoreDocCollector.create(1000, true); + searcher.search(query, null, collector); + TopDocs tops2 = collector.topDocs(); + ScoreDoc[] hits2 = tops2.scoreDocs; + int hitsNum2 = tops2.totalHits; + + assertEquals(hitsNum1, count1); + assertEquals(hitsNum2, count1); + CheckHits.checkEqual(query, hits1, hits2); + } + + @Test + public void testQueries02() throws Exception { + BooleanQuery query = new BooleanQuery(); + query.add(new TermQuery(new Term(field, "1")), BooleanClause.Occur.SHOULD); + query.add(new TermQuery(new Term(field, "1")), BooleanClause.Occur.SHOULD); + + TopScoreDocCollector collector = TopScoreDocCollector.create(1000, false); + searcher.search(query, null, collector); + TopDocs tops1 = collector.topDocs(); + ScoreDoc[] hits1 = tops1.scoreDocs; + int hitsNum1 = tops1.totalHits; + + collector = TopScoreDocCollector.create(1000, true); + searcher.search(query, null, collector); + TopDocs tops2 = collector.topDocs(); + ScoreDoc[] hits2 = tops2.scoreDocs; + int hitsNum2 = tops2.totalHits; + + assertEquals(hitsNum1, count1); + assertEquals(hitsNum2, count1); + CheckHits.checkEqual(query, hits1, hits2); + } +}