diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanArrayScorer.java b/lucene/core/src/java/org/apache/lucene/search/BooleanArrayScorer.java new file mode 100644 index 0000000..ad791eb --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/BooleanArrayScorer.java @@ -0,0 +1,322 @@ +package org.apache.lucene.search; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.IOException; +import java.util.List; + +import org.apache.lucene.search.BooleanQuery.BooleanWeight; + +/** + * This is an improvement of {@link BooleanScorer}. + * It only supports cases where there is at least one MUST clause. + */ +final class BooleanArrayScorer extends Scorer { + + private static final class Bucket { + int doc; // doc id + // score is divided into RS and OS, so that its calculating order + // can be the same as the DAAT procedure. + double requiredScore; // incremental required score + double optionalScore; // incremental optional score + int coord; // count of terms in score + boolean valid; // valid bucket + } + + /** A simple hash table of document scores within a range. 
*/ + private final class BucketTable { + static final int SIZE = 1 << 8; + + private final Bucket[] buckets = new Bucket[SIZE]; + // After collecting more documents, if there are more documents not collected. + boolean more = true; + private int numOfBuckets = 0; + private int numOfValidBuckets = 0; + + BucketTable() { + // Pre-fill to save the lazy init when collecting + // each sub: + for(int idx=0;idxmore is true, while numOfBuckets is 0. + * @return true if more matching documents may remain. + */ + boolean collectMore() throws IOException { + numOfBuckets = 0; + + // Scan requiredDocs full fill the bucket table + while (numOfBuckets < SIZE) { + final int requiredDocID = requiredConjunctionScorer.nextDoc(); + if (requiredDocID == DocIdSetIterator.NO_MORE_DOCS) { + more = false; + break; + } + final Bucket bucket = buckets[numOfBuckets ++]; + bucket.doc = requiredDocID; + bucket.coord = requiredScorers.size(); + bucket.requiredScore = requiredConjunctionScorer.score(); + bucket.optionalScore = 0; + bucket.valid = true; + } + numOfValidBuckets = numOfBuckets; + + // Scan prohibitedDocs to remove docs from bucket table. + for (Scorer prohibitedScorer : prohibitedScorers) { + int i = 0; + while (i < numOfBuckets) { + final Bucket bucket = buckets[i]; + int prohibitedDocID = prohibitedScorer.docID(); + if (prohibitedDocID < bucket.doc) { + // According to the definition of .advance(), + // prohibitedDocID should be less then bucket.doc, + // before calling .advance(bucket.doc); + + // Skip to bucket.doc, so that prohibitedDocID >= bucket.doc + prohibitedDocID = prohibitedScorer.advance(bucket.doc); + } + + if (prohibitedDocID == DocIdSetIterator.NO_MORE_DOCS) { + break; + + } else if (prohibitedDocID == bucket.doc) { + // remove the prohibited bucket. 
+ if (bucket.valid) { + numOfValidBuckets --; + bucket.valid = false; + } + i ++; + + } else { // prohibitedDocID > bucket.doc + i = skipsTo(i, prohibitedDocID); + } + } + } + + // Scan optionalDocs to add coord and score. + // TODO: use countLeft to judge whether a bucket should be removed. +// int countLeft = optionalScorers.size(); + for (Scorer optionalScorer : optionalScorers) { +// countLeft --; + int i = 0; + while (i < numOfBuckets) { + final Bucket bucket = buckets[i]; + int optionalDocID = optionalScorer.docID(); + if (optionalDocID < bucket.doc) { + // According to the definition of .advance(), + // optionalDocID should be less then bucket.doc, + // before calling .advance(bucket.doc); + + // Skip to bucket.doc, so that optionalDocID >= bucket.doc + optionalDocID = optionalScorer.advance(bucket.doc); + } + + if (optionalDocID == DocIdSetIterator.NO_MORE_DOCS) { + break; + + } else if (optionalDocID == bucket.doc) { + bucket.coord ++; + bucket.optionalScore += optionalScorer.score(); + i ++; +// if (oldBucket.coord + countLeft < minNrShouldMatch) remove(oldBucket); + + } else { // optionalDocID > bucket.doc + // current bucket skips to optionalDocID. + i = skipsTo(i, optionalDocID); + } + } + } + + if (more && numOfValidBuckets == 0) { + // If there are more docs not collected, but no doc is collected in this iteration, + // collect more again. + return collectMore(); + } + return more; + } + + + /** + * Skips to target from begin.
+ * NOTE: Undefined when buckets[i].doc >= target. + * @param begin the index begin to skip, MUST s.t. -1 ≤ begin. + * @param target the target doc to skip. + * @return the first index i s.t. buckets[i].doc >= target. + * (numOfBuckets when no filled bucket reaches target) + */ + int skipsTo(int begin, int target) { + /* Phase 1: probe every DELTA-th bucket after begin while the probed doc is still below target. */ + final int DELTA = 16; + int i = begin + DELTA; + while (i < numOfBuckets && buckets[i].doc < target) { + i += DELTA; + } + + /* Phase 2: clamp the probe to the filled part of the table, or return at once on an exact hit. */ + int end = i; + if (end >= numOfBuckets) { + end = numOfBuckets; + + } else if (buckets[end].doc == target) { + return end; + } + + /* Phase 3: linear scan of the last stride, starting one past the previous probe position. */ + i = i - DELTA + 1; + while (i < end && buckets[i].doc < target) { + i ++; + } + return i; + } + + /** + * Moves every sub-scorer that is still positioned before target up to target. + * The buckets themselves are not modified. + * NOTE(review): the outer advance(int) of this scorer uses slowAdvance and no caller of this method is visible here -- confirm whether it is still needed. + * @param target the doc id to advance the sub-scorers to. + */ + void advance(int target) throws IOException { + // advance requiredConjunctionScorer to target doc + if (requiredConjunctionScorer.docID() < target) { + requiredConjunctionScorer.advance(target); + } + + // Scan prohibitedDocs to advance to target doc + for (Scorer prohibitedScorer : prohibitedScorers) { + if (prohibitedScorer.docID() < target) { + prohibitedScorer.advance(target); + } + } + + // Scan optionalDocs to advance to target doc + for (Scorer optionalScorer : optionalScorers) { + if (optionalScorer.docID() < target) { + optionalScorer.advance(target); + } + } + } + } + + /* Window of buckets currently being iterated. */ + private final BucketTable bucketTable = new BucketTable(); + /* Coord factor indexed by the number of matching clauses; filled in the constructor. */ + private final float[] coordFactors; + /* Index into bucketTable.buckets of the current bucket; -1 before the first nextDoc(). */ + private int currentIndex = -1; + /* Value reported by docID(). */ + private int currentDoc = -1; + + final private Scorer requiredConjunctionScorer; + final private List requiredScorers; + final private List optionalScorers; + final private List prohibitedScorers; + // minNrShouldMatch only applies to SHOULD clauses + final private int minNrShouldMatch; + + /** + * Creates the scorer and wraps all required scorers in one ConjunctionScorer. + * @param weight the weight handed to the Scorer super class and used for coord factors. + * @param disableCoord if true every coord factor is 1.0f. + * @param minNrShouldMatch minimum number of optional clauses a document must match. + * @param requiredScorers scorers of the MUST clauses, must not be empty. + * @param optionalScorers scorers of the SHOULD clauses. + * @param prohibitedScorers scorers of the MUST_NOT clauses. + * @param maxCoord second argument passed to weight.coord. + * @throws IllegalArgumentException if requiredScorers is empty or minNrShouldMatch is negative. + */ + BooleanArrayScorer(BooleanWeight weight, boolean disableCoord, int minNrShouldMatch, + List requiredScorers, List optionalScorers, List prohibitedScorers, + int maxCoord) throws IOException { + super(weight); + if (requiredScorers.size() == 0) { + throw new IllegalArgumentException("requriedScorers.size() must be > 0"); + } + if (minNrShouldMatch < 0) { + throw new 
IllegalArgumentException("Minimum number of optional scorers should not be negative"); + } + this.minNrShouldMatch = minNrShouldMatch; + + this.requiredScorers = requiredScorers; + this.optionalScorers = optionalScorers; + this.prohibitedScorers = prohibitedScorers; + this.requiredConjunctionScorer = new ConjunctionScorer( + this.weight, this.requiredScorers.toArray(new Scorer[this.requiredScorers.size()])); + + /* One coord factor per possible number of matching clauses, 0 .. required + optional. */ + coordFactors = new float[requiredScorers.size() + optionalScorers.size() + 1]; + for (int i = 0; i < coordFactors.length; i++) { + coordFactors[i] = disableCoord ? 1.0f : weight.coord(i, maxCoord); + } + } + + /** + * Renders the sub-scorers in query-like form: required ones prefixed with "+", + * optional ones unprefixed, prohibited ones prefixed with "-", separated by spaces + * and wrapped in "boolean(" ... ")". + * Uses an unsynchronized StringBuilder because the buffer never leaves this method. + */ + @Override + public String toString() { + StringBuilder buffer = new StringBuilder(); + buffer.append("boolean("); + for (Scorer requiredScorer : requiredScorers) { + buffer.append("+"); + buffer.append(requiredScorer.toString()); + buffer.append(" "); + } + for (Scorer optionalScorer : optionalScorers) { + buffer.append(optionalScorer.toString()); + buffer.append(" "); + } + for (Scorer prohibitedScorer : prohibitedScorers) { + buffer.append("-"); + buffer.append(prohibitedScorer.toString()); + buffer.append(" "); + } + // Delete the last whitespace + if (buffer.length() > "boolean(".length()) { + buffer.deleteCharAt(buffer.length() - 1); + } + buffer.append(")"); + return buffer.toString(); + } + + /** + * Returns the score of the current bucket: required plus optional score, times the coord factor for the bucket's clause count. + */ + @Override + public float score() throws IOException { + final Bucket bucket = bucketTable.buckets[currentIndex]; + // Cast required score and optional score to float, + // so that the calculation is carried out the same way as in DAAT. 
+ return (float) ((float) bucket.requiredScore + (float) bucket.optionalScore) * coordFactors[bucket.coord]; + } + + /** + * Returns the coord (number of matching clauses) recorded in the current bucket. + */ + @Override + public int freq() throws IOException { + return bucketTable.buckets[currentIndex].coord; + } + + /** + * Returns the doc id set by the last nextDoc() call, or -1 before the first call. + */ + @Override + public int docID() { + return currentDoc; + } + + /** + * Steps to the next bucket that is still valid and whose optional match count + * (coord minus the number of required scorers) reaches minNrShouldMatch, + * refilling the bucket table whenever the current window is used up. + * @return the matching doc id, or NO_MORE_DOCS when the table is exhausted and nothing more can be collected. + */ + @Override + public int nextDoc() throws IOException { + currentIndex ++; + /* Window used up but more docs may exist: collect the next window and restart at its first bucket. */ + if (bucketTable.more && currentIndex >= bucketTable.numOfBuckets) { + bucketTable.collectMore(); + currentIndex = 0; + } + + while (currentIndex < bucketTable.numOfBuckets) { + final Bucket bucket = bucketTable.buckets[currentIndex]; + + if (bucket.valid && bucket.coord - requiredScorers.size() >= minNrShouldMatch) { + return currentDoc = bucket.doc; + } + + currentIndex ++; + /* Same refill step as above, needed when the rejected bucket was the last one of the window. */ + if (bucketTable.more && currentIndex >= bucketTable.numOfBuckets) { + bucketTable.collectMore(); + currentIndex = 0; + } + } + return currentDoc = DocIdSetIterator.NO_MORE_DOCS; + } + + /** + * Advances by delegating to slowAdvance(target); the bucket table is not used for skipping. + */ + @Override + public int advance(int target) throws IOException { + return slowAdvance(target); + } + + /** + * Returns the cost of the conjunction over the required scorers. + */ + @Override + public long cost() { + return requiredConjunctionScorer.cost(); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanLinkedScorer.java b/lucene/core/src/java/org/apache/lucene/search/BooleanLinkedScorer.java new file mode 100644 index 0000000..0ab8e19 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/BooleanLinkedScorer.java @@ -0,0 +1,369 @@ +package org.apache.lucene.search; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.IOException; +import java.util.List; + +import org.apache.lucene.search.BooleanQuery.BooleanWeight; +import org.apache.lucene.util.FixedBitSet; + +/** + * This is an improvement of {@link BooleanScorer}. + * It only supports cases where there is at least one MUST clause. + */ +final class BooleanLinkedScorer extends Scorer { + + + private static final class Bucket { + int doc; // doc id +// double score; // incremental score + double requiredScore; + double optionalScore; + int coord; // count of terms in score + Bucket next; // next valid bucket + Bucket prev; // previous valid bucket + } + + /** A simple hash table of document scores within a range. */ + private final class BucketTable { + static final int SIZE = 1 << 11; + static final int MASK = SIZE - 1; + + private FixedBitSet bitSet = new FixedBitSet(SIZE); + private boolean bitSetIsClean = true; + private final Bucket[] buckets = new Bucket[SIZE]; + private Bucket first = null; // head of valid list + private Bucket last = null; // tail of valid list + // After collecting more documents, if there are more documents not collected. + boolean more = true; + int end = 0; + + BucketTable() { + // Pre-fill to save the lazy init when collecting + // each sub: + for(int idx=0;idx= bucket.doc + prohibitedDocID = prohibitedScorer.advance(bucket.doc); + } + + if (prohibitedDocID == DocIdSetIterator.NO_MORE_DOCS) { + break; + + } else if (prohibitedDocID == bucket.doc) { + // remove the prohibited bucket. 
+ Bucket oldBucket = bucket; + bucket = bucket.next; + remove(oldBucket); + + } else { // prohibitedDocID > bucket.doc + if (prohibitedDocID >= end) { + bucket = null; + } else { + int bucketIndex = bitSet.nextSetBit(prohibitedDocID & MASK); + bucket = bucketIndex < 0 ? null : buckets[bucketIndex]; + } + } + } + } + + // Scan optionalDocs to add coord and score. + // TODO: use countLeft to judge whether a bucket should be removed. +// int countLeft = optionalScorers.size(); + for (Scorer optionalScorer : optionalScorers) { +// countLeft --; + Bucket bucket = first; + while (bucket != null) { + int optionalDocID = optionalScorer.docID(); + if (optionalDocID < bucket.doc) { + // According to the definition of .advance(), + // optionalDocID should be less then bucket.doc, + // before calling .advance(bucket.doc); + + // Skip to bucket.doc, so that optionalDocID >= bucket.doc + optionalDocID = optionalScorer.advance(bucket.doc); + } + + if (optionalDocID == DocIdSetIterator.NO_MORE_DOCS) { + break; + + } else if (optionalDocID == bucket.doc) { + Bucket oldBucket = bucket; + bucket = bucket.next; + oldBucket.coord ++; +// oldBucket.score += optionalScorer.score(); + oldBucket.optionalScore += optionalScorer.score(); +// if (oldBucket.coord + countLeft < minNrShouldMatch) remove(oldBucket); + + } else { // docID > bucket.doc + // current bucket advances to prohibtedDocID. + if (optionalDocID >= end) { + bucket = null; + } else { + int bucketIndex = bitSet.nextSetBit(optionalDocID & MASK); + bucket = bucketIndex < 0 ? null : buckets[bucketIndex]; + } + } + } + } + + if (more && first == null) { + // If there are more docs not collected, but no doc is collected in this iteration, + // collect more again. + return collectMore(); + } + return more; + } + + /** + * Remove bucket from the double-linked list [first, last]. + * @param bucket bucket to be removed from the double-linked list. 
+ * The removed bucket's own prev/next links are left untouched. + */ + private void remove(Bucket bucket) { + bitSet.clear(bucket.doc & MASK); // remove the bucket on bitset + /* Case 1: the bucket is the only element of the list. */ + if (first == bucket && last == bucket) { + first = last = null; + return; + } + /* Case 2: the bucket is the head. */ + if (first == bucket) { + first = first.next; + first.prev = null; + return; + } + /* Case 3: the bucket is the tail. */ + if (last == bucket) { + last = last.prev; + last.next = null; + return; + } + /* Case 4: interior bucket, link its two neighbours to each other. */ + Bucket prev = bucket.prev; + Bucket next = bucket.next; + prev.next = next; + next.prev = prev; + } + + /** + * Add bucket to the tail of double-linked list [first, last]. + * Also sets the bucket's bit (doc & MASK) in the bitset. + * @param bucket bucket to be added to the double-linked list. + */ + private void add(Bucket bucket) { + bitSet.set(bucket.doc & MASK); // add the bucket to bitset + bitSetIsClean = false; // the bitset is not clean now + if (first == null) { + first = last = bucket; + bucket.prev = null; + bucket.next = null; + return; + } + last.next = bucket; + bucket.prev = last; + bucket.next = null; + last = bucket; + } + + /** + * Moves every sub-scorer that is still positioned before target up to target + * and resets the outer scorer's current bucket and current doc. + * @param target the doc id to advance the sub-scorers to. + */ + void advance(int target) throws IOException { + /* target rounded down to a multiple of SIZE (MASK is SIZE - 1). */ + end = target & ~MASK; + // advance requiredConjunctionScorer to target doc + if (requiredConjunctionScorer.docID() < target) { + requiredConjunctionScorer.advance(target); + } + + // Scan prohibitedDocs to advance to target doc. + for (Scorer prohibitedScorer : prohibitedScorers) { + if (prohibitedScorer.docID() < target) { + prohibitedScorer.advance(target); + } + } + + // Scan optionalDocs to advance to target doc. 
+ for (Scorer optionalScorer : optionalScorers) { + if (optionalScorer.docID() < target) { + optionalScorer.advance(target); + } + } + /* Forget the iteration position of the enclosing scorer. */ + current = null; + currentDoc = -1; + } + } + + /* Window of buckets currently being iterated. */ + private final BucketTable bucketTable = new BucketTable(); + /* Coord factor indexed by the number of matching clauses; filled in the constructor. */ + private final float[] coordFactors; + // bucket the scorer is positioned on, null when not positioned + private Bucket current = null; + /* Value reported by docID(). */ + private int currentDoc = -1; + + final private Scorer requiredConjunctionScorer; + final private List requiredScorers; + final private List optionalScorers; + final private List prohibitedScorers; + // minNrShouldMatch only applies to SHOULD clauses + final private int minNrShouldMatch; + + /** + * Creates the scorer and wraps all required scorers in one ConjunctionScorer. + * @param weight the weight handed to the Scorer super class and used for coord factors. + * @param disableCoord if true every coord factor is 1.0f. + * @param minNrShouldMatch minimum number of optional clauses a document must match. + * @param requiredScorers scorers of the MUST clauses, must not be empty. + * @param optionalScorers scorers of the SHOULD clauses. + * @param prohibitedScorers scorers of the MUST_NOT clauses. + * @param maxCoord second argument passed to weight.coord. + * @throws IllegalArgumentException if requiredScorers is empty or minNrShouldMatch is negative. + */ + BooleanLinkedScorer(BooleanWeight weight, boolean disableCoord, int minNrShouldMatch, + List requiredScorers, List optionalScorers, List prohibitedScorers, + int maxCoord) throws IOException { + super(weight); + if (requiredScorers.size() == 0) { + throw new IllegalArgumentException("requriedScorers.size() must be > 0"); + } + if (minNrShouldMatch < 0) { + throw new IllegalArgumentException("Minimum number of optional scorers should not be negative"); + } + this.minNrShouldMatch = minNrShouldMatch; + + this.requiredScorers = requiredScorers; + this.optionalScorers = optionalScorers; + this.prohibitedScorers = prohibitedScorers; + this.requiredConjunctionScorer = new ConjunctionScorer( + this.weight, this.requiredScorers.toArray(new Scorer[this.requiredScorers.size()])); + + /* One coord factor per possible number of matching clauses, 0 .. required + optional. */ + coordFactors = new float[requiredScorers.size() + optionalScorers.size() + 1]; + for (int i = 0; i < coordFactors.length; i++) { + coordFactors[i] = disableCoord ? 
1.0f : weight.coord(i, maxCoord); + } + } + + /** + * Renders the sub-scorers in query-like form: required ones prefixed with "+", + * optional ones unprefixed, prohibited ones prefixed with "-", separated by spaces + * and wrapped in "boolean(" ... ")". + * Uses an unsynchronized StringBuilder because the buffer never leaves this method. + */ + @Override + public String toString() { + StringBuilder buffer = new StringBuilder(); + buffer.append("boolean("); + for (Scorer requiredScorer : requiredScorers) { + buffer.append("+"); + buffer.append(requiredScorer.toString()); + buffer.append(" "); + } + for (Scorer optionalScorer : optionalScorers) { + buffer.append(optionalScorer.toString()); + buffer.append(" "); + } + for (Scorer prohibitedScorer : prohibitedScorers) { + buffer.append("-"); + buffer.append(prohibitedScorer.toString()); + buffer.append(" "); + } + // Delete the last whitespace + if (buffer.length() > "boolean(".length()) { + buffer.deleteCharAt(buffer.length() - 1); + } + buffer.append(")"); + return buffer.toString(); + } + + /** + * Returns required plus optional score of the current bucket times the coord factor + * for its clause count, or NaN when the scorer is not positioned on a bucket. + */ + @Override + public float score() throws IOException { + return current != null ? + ((float) current.requiredScore + (float) current.optionalScore) * coordFactors[current.coord] : Float.NaN; + } + + /** + * Returns the coord of the current bucket, or 0 when the scorer is not positioned on a bucket. + */ + @Override + public int freq() throws IOException { + return current != null ? current.coord : 0; + } + + /** + * Returns the doc id of the last match, -1 before the first one. + */ + @Override + public int docID() { + return currentDoc; + } + + /** + * Walks the linked list of remaining buckets to the next one whose optional match count + * (coord minus the number of required scorers) reaches minNrShouldMatch, + * collecting a new window whenever the list is used up. + * @return the matching doc id, or NO_MORE_DOCS when nothing is left. + */ + @Override + public int nextDoc() throws IOException { + if (current != null) { + current = current.next; + } + /* List used up but more docs may exist: collect the next window and restart at its head. */ + if (bucketTable.more && current == null) { + bucketTable.collectMore(); + current = bucketTable.first; + } + + while (current != null) { + if (current.coord - requiredScorers.size() >= minNrShouldMatch) { + currentDoc = current.doc; + return currentDoc; + + } else { + current = current.next; + } + + /* Same refill step as above, needed when the rejected bucket was the last of the list. */ + if (bucketTable.more && current == null) { + bucketTable.collectMore(); + current = bucketTable.first; + } + } + return currentDoc = DocIdSetIterator.NO_MORE_DOCS; + } + + /** + * Advances to the first match at or after target via slowAdvance, first letting the + * bucket table skip its sub-scorers when target is not below bucketTable.end. + */ + @Override + public int advance(int target) throws IOException { + // TODO: should use trickier method to skip documents. 
+ /* Target below bucketTable.end: step there doc by doc without repositioning the sub-scorers. */ + if (target < bucketTable.end) { + return slowAdvance(target); + + } else { + /* NOTE(review): bucketTable.advance leaves the required conjunction positioned on a doc >= target; if collectMore() pulls docs with nextDoc(), as the array variant does, that doc would be skipped -- confirm. */ + bucketTable.advance(target); + return slowAdvance(target); + } + } + + /** + * Returns the cost of the conjunction over the required scorers. + */ + @Override + public long cost() { + return requiredConjunctionScorer.cost(); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanMixedScorerDecider.java b/lucene/core/src/java/org/apache/lucene/search/BooleanMixedScorerDecider.java new file mode 100644 index 0000000..8708c09 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/BooleanMixedScorerDecider.java @@ -0,0 +1,191 @@ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import org.apache.lucene.search.BooleanArrayScorer; +import org.apache.lucene.search.BooleanQuery.BooleanWeight; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** + * Picks, once at construction time, the Scorer for a set of required, optional and + * prohibited sub-scorers: a BooleanArrayScorer when the cost heuristics in + * simplifiedScorer fire, otherwise a combination of conjunction, disjunction and + * exclusion scorers. The result is exposed through getScorer(). + */ +final public class BooleanMixedScorerDecider { + + /* The chosen scorer; null when simplifiedScorer decided that nothing can match. */ + final private Scorer scorer; + final private BooleanWeight weight; + final private boolean disableCoord; + final private int maxCoord; + + /** + * Stores the coord settings and immediately selects the scorer. + * NOTE: the required and optional lists may be modified, see simplifiedScorer. + */ + BooleanMixedScorerDecider(BooleanWeight weight, boolean disableCoord, int minShouldMatch, + List required, List optional, List prohibited, + int maxCoord) throws IOException { + this.weight = weight; + this.disableCoord = disableCoord; + this.maxCoord = maxCoord; + scorer = simplifiedScorer(required, optional, prohibited, minShouldMatch); + } + + /** + * Returns the scorer selected by the constructor; may be null. + */ + public Scorer getScorer() { + return scorer; + } + + /** + * Applies the scorer simplifications and selects the scorer implementation. + * When the number of optional scorers equals minNrShouldMatch, the optional scorers + * are moved into the required list (both caller lists are modified). + * @return the selected scorer, or null when there are no required and optional + * scorers or fewer optional scorers than must match. + */ + private Scorer simplifiedScorer(List required, List optional, List prohibited, + int minNrShouldMatch) throws IOException { + int minShouldMatch = minNrShouldMatch; + // scorer simplifications: + + if (optional.size() == minShouldMatch) { + // any optional clauses are in fact required + required.addAll(optional); + optional.clear(); + minShouldMatch = 0; + } + + if (required.isEmpty() && optional.isEmpty()) { + // no required and optional clauses. + return null; + } else if (optional.size() < minShouldMatch) { + // either >1 req scorer, or there are 0 req scorers and at least 1 + // optional scorer. 
Therefore if there are not enough optional scorers + // no documents will be matched by the query + return null; + } + + // TODO: Find a better trigger condition for choosing BooleanArrayScorer + // Heuristic: with more than 3 optional (or prohibited) scorers, use the array scorer when + // the cost of the first required scorer is less than twice that of the first optional (prohibited) one. + if (!required.isEmpty() && optional.size() > 3) { + float times = (float) required.get(0).cost() / optional.get(0).cost(); + if (times < 2) return new BooleanArrayScorer(weight, disableCoord, minShouldMatch, required, optional, prohibited, maxCoord); + } + if (!required.isEmpty() && prohibited.size() > 3) { + float times = (float) required.get(0).cost() / prohibited.get(0).cost(); + if (times < 2) return new BooleanArrayScorer(weight, disableCoord, minShouldMatch, required, optional, prohibited, maxCoord); + } + + // three cases: conjunction, disjunction, or mix + + // pure conjunction + if (optional.isEmpty()) { + return excl(req(required, disableCoord), prohibited); + } + + // pure disjunction + if (required.isEmpty()) { + return excl(opt(optional, minShouldMatch, disableCoord), prohibited); + } + + // conjunction-disjunction mix: + // we create the required and optional pieces with coord disabled, and then + // combine the two: if minNrShouldMatch > 0, then its a conjunction: because the + // optional side must match. 
otherwise its required + optional, factoring the + // number of optional terms into the coord calculation + + Scorer req = excl(req(required, true), prohibited); + Scorer opt = opt(optional, minShouldMatch, true); + + // TODO: clean this up: its horrible + if (disableCoord) { + if (minShouldMatch > 0) { + return new ConjunctionScorer(weight, new Scorer[] { req, opt }, 1F); + } else { + return new ReqOptSumScorer(req, opt); + } + } else if (optional.size() == 1) { + if (minShouldMatch > 0) { + return new ConjunctionScorer(weight, new Scorer[] { req, opt }, weight.coord(required.size()+1, maxCoord)); + } else { + float coordReq = weight.coord(required.size(), maxCoord); + float coordBoth = weight.coord(required.size() + 1, maxCoord); + return new BooleanTopLevelScorers.ReqSingleOptScorer(req, opt, coordReq, coordBoth); + } + } else { + if (minShouldMatch > 0) { + return new BooleanTopLevelScorers.CoordinatingConjunctionScorer(weight, coords(), req, required.size(), opt); + } else { + return new BooleanTopLevelScorers.ReqMultiOptScorer(req, opt, required.size(), coords()); + } + } + } + + /** + * Builds the scorer for the required side: the single scorer itself (wrapped in a + * BoostedScorer when coord is enabled and maxCoord > 1), or a ConjunctionScorer over all of them. + */ + private Scorer req(List required, boolean disableCoord) { + if (required.size() == 1) { + Scorer req = required.get(0); + if (!disableCoord && maxCoord > 1) { + return new BooleanTopLevelScorers.BoostedScorer(req, weight.coord(1, maxCoord)); + } else { + return req; + } + } else { + return new ConjunctionScorer(weight, + required.toArray(new Scorer[required.size()]), + disableCoord ? 
1.0F : weight.coord(required.size(), maxCoord)); + } + } + + /** + * Wraps main so that documents matched by any prohibited scorer are excluded. + * Returns main unchanged when there is no prohibited scorer; several prohibited + * scorers are merged into one DisjunctionSumScorer with all coords set to 1. + */ + private Scorer excl(Scorer main, List prohibited) throws IOException { + if (prohibited.isEmpty()) { + return main; + } else if (prohibited.size() == 1) { + return new ReqExclScorer(main, prohibited.get(0)); + } else { + float coords[] = new float[prohibited.size()+1]; + Arrays.fill(coords, 1F); + return new ReqExclScorer(main, + new DisjunctionSumScorer(weight, + prohibited.toArray(new Scorer[prohibited.size()]), + coords)); + } + } + + /** + * Builds the scorer for the optional side: the single scorer itself (wrapped in a + * BoostedScorer when coord is enabled and maxCoord > 1), a MinShouldMatchSumScorer + * when minShouldMatch > 1, otherwise a DisjunctionSumScorer. + */ + private Scorer opt(List optional, int minShouldMatch, boolean disableCoord) throws IOException { + if (optional.size() == 1) { + Scorer opt = optional.get(0); + if (!disableCoord && maxCoord > 1) { + return new BooleanTopLevelScorers.BoostedScorer(opt, weight.coord(1, maxCoord)); + } else { + return opt; + } + } else { + float coords[]; + if (disableCoord) { + coords = new float[optional.size()+1]; + Arrays.fill(coords, 1F); + } else { + coords = coords(); + } + if (minShouldMatch > 1) { + return new MinShouldMatchSumScorer(weight, optional, minShouldMatch, coords); + } else { + return new DisjunctionSumScorer(weight, + optional.toArray(new Scorer[optional.size()]), + coords); + } + } + } + + /** + * Returns the coord table: index 0 holds 0, index i holds weight.coord(i, maxCoord) for i in 1..maxCoord. + */ + private float[] coords() { + float[] coords = new float[maxCoord+1]; + coords[0] = 0F; + for (int i = 1; i < coords.length; i++) { + coords[i] = weight.coord(i, maxCoord); + } + return coords; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java b/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java index 4d7635d..ea299ce 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java @@ -20,6 +20,7 @@ package org.apache.lucene.search; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Set; @@ -314,9 +315,15 @@ public class 
BooleanQuery extends Query implements Iterable { return super.bulkScorer(context, scoreDocsInOrder, acceptDocs); } + List required = new ArrayList(); List prohibited = new ArrayList(); +// List prohibitedScorers = new ArrayList(); List optional = new ArrayList(); +// List optionalScorers = new ArrayList(); Iterator cIter = clauses.iterator(); +// float requiredCost = 0; +// float prohibitedCost = 0; +// float optionalCost = 0; for (Weight w : weights) { BooleanClause c = cIter.next(); BulkScorer subScorer = w.bulkScorer(context, false, acceptDocs); @@ -328,15 +335,95 @@ public class BooleanQuery extends Query implements Iterable { // TODO: there are some cases where BooleanScorer // would handle conjunctions faster than // BooleanScorer2... - return super.bulkScorer(context, scoreDocsInOrder, acceptDocs); + + Scorer requiredSubScorer = w.scorer(context, acceptDocs); + // if no doc matches required, then return null to say + // no doc matches this Boolean Query. + if (requiredSubScorer == null) { + return null; + } + required.add(requiredSubScorer); +// requiredCost += requiredSubScorer.cost(); + } else if (c.isProhibited()) { prohibited.add(subScorer); +// prohibitedCost += subScorer.cost(); } else { optional.add(subScorer); +// optionalCost += subScorer.cost(); } } - - return new BooleanScorer(this, disableCoord, minNrShouldMatch, optional, prohibited, maxCoord); + return new BooleanScorer(this, disableCoord, minNrShouldMatch, required, optional, prohibited, maxCoord); + } + + private Scorer getMixScorer(List required, List optional, + Scorer req, Scorer opt, int minShouldMatch) { + // TODO: clean this up: its horrible + if (disableCoord) { + if (minShouldMatch > 0) { + return new ConjunctionScorer(this, new Scorer[] { req, opt }, 1F); + } else { + return new ReqOptSumScorer(req, opt); + } + } else if (optional.size() == 1) { + if (minShouldMatch > 0) { + return new ConjunctionScorer(this, new Scorer[] { req, opt }, coord(required.size()+1, maxCoord)); + } else { 
+ float coordReq = coord(required.size(), maxCoord); + float coordBoth = coord(required.size() + 1, maxCoord); + return new BooleanTopLevelScorers.ReqSingleOptScorer(req, opt, coordReq, coordBoth); + } + } else { + if (minShouldMatch > 0) { + return new BooleanTopLevelScorers.CoordinatingConjunctionScorer(this, coords(), req, required.size(), opt); + } else { + return new BooleanTopLevelScorers.ReqMultiOptScorer(req, opt, required.size(), coords()); + } + } + } + + /** + * Suggests a scorer implementation for the required + prohibited part of a query. + * "High"/"Low" in the comments below mean a cost above/below 100000. + * @param requiredCost cost estimate of the required side. + * @param prohibitedSize number of prohibited clauses. + * @param prohibitedCost cost estimate of the prohibited side. + * @return BooleanScorer.class, BooleanLinkedScorer.class, BooleanArrayScorer.class, or null when no suggestion is made. + */ + private Class scorerClassForNot(float requiredCost, int prohibitedSize, float prohibitedCost) { + if (prohibitedSize > 3 && prohibitedSize < 50 && + requiredCost > 100000F && prohibitedCost > 100000F) { + return BooleanScorer.class; + } + if (prohibitedSize > 50) { // Tons + if (prohibitedCost > 100000F) { // HighNot + return BooleanLinkedScorer.class; + } + if (requiredCost < 100000F && prohibitedCost < 100000F) { // LowAnd, LowNot + return BooleanArrayScorer.class; + } + return null; + } + if (prohibitedSize > 3) { // Some + if (requiredCost < 100000F && prohibitedCost > 100000F) { // LowAnd, HighNot + return BooleanArrayScorer.class; + } + } + return null; + } + + /** + * Suggests a scorer implementation for the required + optional part of a query. + * "High"/"Low" in the comments below mean a cost above/below 100000. + * The first test checks requiredCost and optionalCost (HighAnd, HighOr), mirroring + * scorerClassForNot; it used to test optionalCost twice, which made the "Some" branch + * unreachable for 3 < optionalSize < 50. + * @param requiredCost cost estimate of the required side. + * @param optionalSize number of optional clauses. + * @param optionalCost cost estimate of the optional side. + * @return BooleanScorer.class, BooleanArrayScorer.class, BooleanLinkedScorer.class, or null when no suggestion is made. + */ + private Class scorerClassForOr(float requiredCost, int optionalSize, float optionalCost) { + if (optionalSize > 3 && optionalSize < 50 && + requiredCost > 100000F && optionalCost > 100000F) { + return BooleanScorer.class; + } + if (optionalSize > 50) { // Tons + if (requiredCost < 100000F) { // LowAnd + return BooleanArrayScorer.class; + } + if (requiredCost > 100000F && optionalCost > 100000F) { // HighAnd, HighOr + return BooleanLinkedScorer.class; + } + return null; + } + if (optionalSize > 3) { // Some + if (optionalCost > 100000F) { // HighOr + return BooleanArrayScorer.class; + } + } + return null; + } @Override @@ -347,6 +434,10 @@ public class BooleanQuery extends Query implements Iterable { // we will optimize and move these to required, making this 0 int minShouldMatch = minNrShouldMatch; + float requiredCost = Float.MAX_VALUE; + 
float optionalCost = 0F; + float prohibitedCost = 0F; + List required = new ArrayList<>(); List prohibited = new ArrayList<>(); List optional = new ArrayList<>(); @@ -360,15 +451,20 @@ public class BooleanQuery extends Query implements Iterable { } } else if (c.isRequired()) { required.add(subScorer); + requiredCost = Math.min(requiredCost, subScorer.cost()); } else if (c.isProhibited()) { prohibited.add(subScorer); + prohibitedCost += subScorer.cost(); } else { optional.add(subScorer); + optionalCost += subScorer.cost(); } } + requiredCost *= Math.pow(0.97, required.size()); + prohibitedCost /= prohibited.size(); + optionalCost /= optional.size(); // scorer simplifications: - if (optional.size() == minShouldMatch) { // any optional clauses are in fact required required.addAll(optional); @@ -385,50 +481,73 @@ public class BooleanQuery extends Query implements Iterable { // no documents will be matched by the query return null; } - - // three cases: conjunction, disjunction, or mix - - // pure conjunction - if (optional.isEmpty()) { - return excl(req(required, disableCoord), prohibited); - } - + + // three cases: disjunction, conjunction, or mix // pure disjunction if (required.isEmpty()) { + // DAAT is the only choice return excl(opt(optional, minShouldMatch, disableCoord), prohibited); } + Scorer req; + Class scorerNotClass = scorerClassForNot(requiredCost, prohibited.size(), prohibitedCost); + if (scorerNotClass == BooleanArrayScorer.class) { + req = new BooleanArrayScorer(this, !optional.isEmpty() || disableCoord, 0, + required, Collections. emptyList(), prohibited, maxCoord); + } else if (scorerNotClass == BooleanLinkedScorer.class) { + req = new BooleanLinkedScorer(this, !optional.isEmpty() || disableCoord, 0, + required, Collections. 
emptyList(), prohibited, maxCoord); + } else { + req = excl(req(required, !optional.isEmpty() || disableCoord), prohibited); + } + + // required is not empty + // pure conjunction + if (optional.isEmpty()) { + return req; + } + + requiredCost *= Math.pow(0.97, prohibited.size()); // conjunction-disjunction mix: // we create the required and optional pieces with coord disabled, and then // combine the two: if minNrShouldMatch > 0, then its a conjunction: because the // optional side must match. otherwise its required + optional, factoring the // number of optional terms into the coord calculation - - Scorer req = excl(req(required, true), prohibited); - Scorer opt = opt(optional, minShouldMatch, true); - - // TODO: clean this up: its horrible - if (disableCoord) { - if (minShouldMatch > 0) { - return new ConjunctionScorer(this, new Scorer[] { req, opt }, 1F); - } else { - return new ReqOptSumScorer(req, opt); - } - } else if (optional.size() == 1) { - if (minShouldMatch > 0) { - return new ConjunctionScorer(this, new Scorer[] { req, opt }, coord(required.size()+1, maxCoord)); - } else { - float coordReq = coord(required.size(), maxCoord); - float coordBoth = coord(required.size() + 1, maxCoord); - return new BooleanTopLevelScorers.ReqSingleOptScorer(req, opt, coordReq, coordBoth); - } - } else { - if (minShouldMatch > 0) { - return new BooleanTopLevelScorers.CoordinatingConjunctionScorer(this, coords(), req, required.size(), opt); - } else { - return new BooleanTopLevelScorers.ReqMultiOptScorer(req, opt, required.size(), coords()); - } + Scorer opt; + Class scorerOrClass = scorerClassForOr(requiredCost, optional.size(), optionalCost); + if (scorerOrClass == BooleanArrayScorer.class) { + return new BooleanArrayScorer(this, disableCoord, minShouldMatch, + required, optional, Collections. emptyList(), maxCoord); } + if (scorerOrClass == BooleanLinkedScorer.class) { + return new BooleanLinkedScorer(this, disableCoord, minShouldMatch, + required, optional, Collections. 
emptyList(), maxCoord); + } + opt = opt(optional, minShouldMatch, true); + + return getMixScorer(required, optional, req, opt, minShouldMatch); +// // TODO: clean this up: its horrible +// if (disableCoord) { +// if (minShouldMatch > 0) { +// return new ConjunctionScorer(this, new Scorer[] { req, opt }, 1F); +// } else { +// return new ReqOptSumScorer(req, opt); +// } +// } else if (optional.size() == 1) { +// if (minShouldMatch > 0) { +// return new ConjunctionScorer(this, new Scorer[] { req, opt }, coord(required.size()+1, maxCoord)); +// } else { +// float coordReq = coord(required.size(), maxCoord); +// float coordBoth = coord(required.size() + 1, maxCoord); +// return new BooleanTopLevelScorers.ReqSingleOptScorer(req, opt, coordReq, coordBoth); +// } +// } else { +// if (minShouldMatch > 0) { +// return new BooleanTopLevelScorers.CoordinatingConjunctionScorer(this, coords(), req, required.size(), opt); +// } else { +// return new BooleanTopLevelScorers.ReqMultiOptScorer(req, opt, required.size(), coords()); +// } +// } } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanScorer.java b/lucene/core/src/java/org/apache/lucene/search/BooleanScorer.java index 173bb44..c08efd5 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BooleanScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/BooleanScorer.java @@ -61,7 +61,7 @@ import org.apache.lucene.search.BooleanQuery.BooleanWeight; final class BooleanScorer extends BulkScorer { - private static final class BooleanScorerCollector extends SimpleCollector { + private final class BooleanScorerCollector extends SimpleCollector { private BucketTable bucketTable; private int mask; private Scorer scorer; @@ -77,18 +77,24 @@ final class BooleanScorer extends BulkScorer { final int i = doc & BucketTable.MASK; final Bucket bucket = table.buckets[i]; + final int coord = (mask & REQUIRED_MASK) == REQUIRED_MASK ? 
requiredNrMatch : 1; if (bucket.doc != doc) { // invalid bucket bucket.doc = doc; // set doc - bucket.score = scorer.score(); // initialize score + bucket.requiredScore = 0; // initialize required score + bucket.optionalScore = 0; // initialize optional score bucket.bits = mask; // initialize mask - bucket.coord = 1; // initialize coord - + bucket.coord = coord; // initialize coord bucket.next = table.first; // push onto valid list table.first = bucket; } else { // valid bucket - bucket.score += scorer.score(); // increment score bucket.bits |= mask; // add bits in mask - bucket.coord++; // increment coord + bucket.coord += coord; // increment coord + } + + if ((mask & REQUIRED_MASK) == REQUIRED_MASK) { // Required doc + bucket.requiredScore += scorer.score(); + } else if (mask == 0) { // Optional doc + bucket.optionalScore += scorer.score(); } } @@ -106,17 +112,18 @@ final class BooleanScorer extends BulkScorer { static final class Bucket { int doc = -1; // tells if bucket is valid - double score; // incremental score - // TODO: break out bool anyProhibited, int - // numRequiredMatched; then we can remove 32 limit on - // required clauses + // score is divided into RS and OS, so that its calculating order + // can be the same as the schedule of DAAT. + double requiredScore; // incremental required score + double optionalScore; // incremental optional score + int bits; // used for bool constraints int coord; // count of terms in score Bucket next; // next valid bucket } /** A simple hash table of document scores within a range. 
*/ - static final class BucketTable { + final class BucketTable { public static final int SIZE = 1 << 11; public static final int MASK = SIZE - 1; @@ -140,8 +147,7 @@ final class BooleanScorer extends BulkScorer { static final class SubScorer { public BulkScorer scorer; - // TODO: re-enable this if BQ ever sends us required clauses - //public boolean required = false; + public boolean required = false; public boolean prohibited; public LeafCollector collector; public SubScorer next; @@ -149,13 +155,9 @@ final class BooleanScorer extends BulkScorer { public SubScorer(BulkScorer scorer, boolean required, boolean prohibited, LeafCollector collector, SubScorer next) { - if (required) { - throw new IllegalArgumentException("this scorer cannot handle required=true"); - } this.scorer = scorer; this.more = true; - // TODO: re-enable this if BQ ever sends us required clauses - //this.required = required; + this.required = required; this.prohibited = prohibited; this.collector = collector; this.next = next; @@ -165,20 +167,32 @@ final class BooleanScorer extends BulkScorer { private SubScorer scorers = null; private BucketTable bucketTable = new BucketTable(); private final float[] coordFactors; - // TODO: re-enable this if BQ ever sends us required clauses - //private int requiredMask = 0; + // minNrShouldMatch only applies to SHOULD clauses private final int minNrShouldMatch; private int end; private Bucket current; // Any time a prohibited clause matches we set bit 0: private static final int PROHIBITED_MASK = 1; + // Any time a required clause matches we set bit 1: + private static final int REQUIRED_MASK = 2; + // requiredNrMatch applies to MUST clauses + private final int requiredNrMatch; private final Weight weight; - BooleanScorer(BooleanWeight weight, boolean disableCoord, int minNrShouldMatch, - List optionalScorers, List prohibitedScorers, int maxCoord) throws IOException { + BooleanScorer(BooleanWeight weight, boolean disableCoord, int minNrShouldMatch, + List 
requiredScorers, List optionalScorers, List prohibitedScorers, + int maxCoord) throws IOException { + this.minNrShouldMatch = minNrShouldMatch; this.weight = weight; + + this.requiredNrMatch = requiredScorers.size(); + if (this.requiredNrMatch > 0) { + BulkScorer requiredScorer = new Weight.DefaultBulkScorer(new ConjunctionScorer( + this.weight, requiredScorers.toArray(new Scorer[requiredScorers.size()]))); + scorers = new SubScorer(requiredScorer, true, false, bucketTable.newCollector(REQUIRED_MASK), scorers); + } for (BulkScorer scorer : optionalScorers) { scorers = new SubScorer(scorer, false, false, bucketTable.newCollector(0), scorers); @@ -188,7 +202,7 @@ final class BooleanScorer extends BulkScorer { scorers = new SubScorer(scorer, false, true, bucketTable.newCollector(PROHIBITED_MASK), scorers); } - coordFactors = new float[optionalScorers.size() + 1]; + coordFactors = new float[requiredScorers.size() + optionalScorers.size() + 1]; for (int i = 0; i < coordFactors.length; i++) { coordFactors[i] = disableCoord ? 
1.0f : weight.coord(i, maxCoord); } @@ -209,12 +223,9 @@ final class BooleanScorer extends BulkScorer { while (current != null) { // more queued // check prohibited & required - if ((current.bits & PROHIBITED_MASK) == 0) { + if ((current.bits & PROHIBITED_MASK) == 0 && + (requiredNrMatch == 0 || (current.bits & REQUIRED_MASK) == REQUIRED_MASK)) { - // TODO: re-enable this if BQ ever sends us required - // clauses - //&& (current.bits & requiredMask) == requiredMask) { - // NOTE: Lucene always passes max = // Integer.MAX_VALUE today, because we never embed // a BooleanScorer inside another (even though @@ -229,8 +240,10 @@ final class BooleanScorer extends BulkScorer { continue; } - if (current.coord >= minNrShouldMatch) { - fs.score = (float) (current.score * coordFactors[current.coord]); + if (current.coord - requiredNrMatch >= minNrShouldMatch) { + // Cast required score and optional score to float, + // in order to make the calculating procedure is the same as DAAT. + fs.score = (float) (((float) current.requiredScore + (float) current.optionalScore) * coordFactors[current.coord]); fs.doc = current.doc; fs.freq = current.coord; collector.collect(current.doc); diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanScorerInOrder.java b/lucene/core/src/java/org/apache/lucene/search/BooleanScorerInOrder.java new file mode 100644 index 0000000..2d31ccd --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/BooleanScorerInOrder.java @@ -0,0 +1,324 @@ +package org.apache.lucene.search; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.IOException; +import java.util.List; + +import org.apache.lucene.search.BooleanQuery.BooleanWeight; +import org.apache.lucene.util.FixedBitSet; + +/** + * This is an improvement of {@link BooleanScorer}. + * It supports all cases even when there're MUST clauses. + */ +final class BooleanScorerInOrder extends Scorer { + + + private static final class Bucket { + int doc; // doc id + // score is divided into RS and OS, so that its calculating order + // can be the same as the schedule of DAAT. + double requiredScore; // incremental required score + double optionalScore; // incremental optional score + int coord; // count of terms in score + } + + /** A simple hash table of document scores within a range. */ + private final class BucketTable { + static final int SIZE = 1 << 11; + static final int MASK = SIZE - 1; + + private FixedBitSet bitSet = new FixedBitSet(SIZE); + private boolean bitSetIsClean = true; + private final Bucket[] buckets = new Bucket[SIZE]; + // After collecting more documents, if there are more documents not collected. 
+ boolean more = true; + int end = 0; + + BucketTable() { + // Pre-fill to save the lazy init when collecting + // each sub: + for(int idx=0;idx= 0) { + final Bucket bucket = buckets[bucketIndex]; + int prohibitedDocID = prohibitedScorer.docID(); + if (prohibitedDocID < bucket.doc) { + // According to the definition of .advance(), + // prohibitedDocID should be less then bucket.doc, + // before calling .advance(bucket.doc); + + // Skip to bucket.doc, so that prohibitedDocID >= bucket.doc + prohibitedDocID = prohibitedScorer.advance(bucket.doc); + } + + if (prohibitedDocID == DocIdSetIterator.NO_MORE_DOCS) { + break; + + } else if (prohibitedDocID == bucket.doc) { + // remove the prohibited bucket. + bitSet.clear(bucketIndex); + bucketIndex = bucketIndex + 1 < SIZE ? bitSet.nextSetBit(bucketIndex + 1) : -1; + + } else { // prohibitedDocID > bucket.doc + if (prohibitedDocID >= end) { + bucketIndex = -1; + } else { + bucketIndex = bitSet.nextSetBit(prohibitedDocID & MASK); + } + } + } + } + + // Scan optionalDocs to add coord and score. + // TODO: use countLeft to judge whether a bucket should be removed. +// int countLeft = optionalScorers.size(); + for (Scorer optionalScorer : optionalScorers) { +// countLeft --; + int bucketIndex = bitSet.nextSetBit(0); + while (bucketIndex >= 0) { + final Bucket bucket = buckets[bucketIndex]; + int optionalDocID = optionalScorer.docID(); + if (optionalDocID < bucket.doc) { + // According to the definition of .advance(), + // optionalDocID should be less then bucket.doc, + // before calling .advance(bucket.doc); + + // Skip to bucket.doc, so that optionalDocID >= bucket.doc + optionalDocID = optionalScorer.advance(bucket.doc); + } + + if (optionalDocID == DocIdSetIterator.NO_MORE_DOCS) { + break; + + } else if (optionalDocID == bucket.doc) { + bucket.coord ++; + bucket.optionalScore += optionalScorer.score(); + bucketIndex = bucketIndex + 1 < SIZE ? 
bitSet.nextSetBit(bucketIndex + 1) : -1; +// if (bucket.coord + countLeft < minNrShouldMatch) remove(bucket); + + } else { // optionalDocID > bucket.doc + // current bucket advances to prohibtedDocID. + if (optionalDocID >= end) { + bucketIndex = -1; + } else { + bucketIndex = bitSet.nextSetBit(optionalDocID & MASK); + } + } + } + } + + if (more && bitSet.nextSetBit(0) < 0) { + // If there are more docs not collected, but no doc is collected in this iteration, + // collect more again. + return collectMore(); + } + return more; + } + + + void advance(int target) throws IOException { + end = target & ~MASK; + // advance requiredConjunctionScorer to target doc + if (requiredConjunctionScorer.docID() < target) { + requiredConjunctionScorer.advance(target); + } + + // Scan prohibitedDocs to advance to target doc. + for (Scorer prohibitedScorer : prohibitedScorers) { + if (prohibitedScorer.docID() < target) { + prohibitedScorer.advance(target); + } + } + + // Scan optionalDocs to advance to target doc. 
+ for (Scorer optionalScorer : optionalScorers) { + if (optionalScorer.docID() < target) { + optionalScorer.advance(target); + } + } + currentIndex = -1; + currentDoc = -1; + } + } + + private final BucketTable bucketTable = new BucketTable(); + private final float[] coordFactors; + private int currentIndex = -1; + private int currentDoc = -1; + + final private Scorer requiredConjunctionScorer; + final private List requiredScorers; + final private List optionalScorers; + final private List prohibitedScorers; + // minNrShouldMatch only applies to SHOULD clauses + final private int minNrShouldMatch; + + BooleanScorerInOrder(BooleanWeight weight, boolean disableCoord, int minNrShouldMatch, + List requiredScorers, List optionalScorers, List prohibitedScorers, + int maxCoord) throws IOException { + super(weight); + if (requiredScorers.size() == 0) { + throw new IllegalArgumentException("requriedScorers.size() must be > 0"); + } + if (minNrShouldMatch < 0) { + throw new IllegalArgumentException("Minimum number of optional scorers should not be negative"); + } + this.minNrShouldMatch = minNrShouldMatch; + + this.requiredScorers = requiredScorers; + this.optionalScorers = optionalScorers; + this.prohibitedScorers = prohibitedScorers; + this.requiredConjunctionScorer = new ConjunctionScorer( + this.weight, this.requiredScorers.toArray(new Scorer[this.requiredScorers.size()])); + + coordFactors = new float[requiredScorers.size() + optionalScorers.size() + 1]; + for (int i = 0; i < coordFactors.length; i++) { + coordFactors[i] = disableCoord ? 
1.0f : weight.coord(i, maxCoord); + } + } + + @Override + public String toString() { + StringBuffer buffer = new StringBuffer(); + buffer.append("boolean("); + for (Scorer requiredScorer : requiredScorers) { + buffer.append("+"); + buffer.append(requiredScorer.toString()); + buffer.append(" "); + } + for (Scorer optionalScorer : optionalScorers) { + buffer.append(optionalScorer.toString()); + buffer.append(" "); + } + for (Scorer prohibitedScorer : prohibitedScorers) { + buffer.append("-"); + buffer.append(prohibitedScorer.toString()); + buffer.append(" "); + } + // Delete the last whitespace + if (buffer.length() > "boolean(".length()) { + buffer.deleteCharAt(buffer.length() - 1); + } + buffer.append(")"); + return buffer.toString(); + } + + @Override + public float score() throws IOException { + if (currentIndex < 0) return Float.NaN; + final Bucket bucket = bucketTable.buckets[currentIndex]; + // Cast required score and optional score to float, + // in order to make the calculating procedure is the same as DAAT. + return (float) (((float) bucket.requiredScore + (float) bucket.optionalScore) * coordFactors[bucket.coord]); + } + + @Override + public int freq() throws IOException { + return currentIndex >= 0 ? bucketTable.buckets[currentIndex].coord : 0; + } + + @Override + public int docID() { + return currentDoc; + } + + @Override + public int nextDoc() throws IOException { + if (currentIndex >= 0) { + currentIndex = currentIndex + 1 < BucketTable.SIZE ? bucketTable.bitSet.nextSetBit(currentIndex + 1) : -1; + } + if (bucketTable.more && currentIndex < 0) { + bucketTable.collectMore(); + currentIndex = bucketTable.bitSet.nextSetBit(0); + } + + while (currentIndex >= 0) { + final Bucket bucket = bucketTable.buckets[currentIndex]; + if (bucket.coord - requiredScorers.size() >= minNrShouldMatch) { + return currentDoc = bucket.doc; + + } else { + currentIndex = currentIndex + 1 < BucketTable.SIZE ? 
bucketTable.bitSet.nextSetBit(currentIndex + 1) : -1; + } + + if (bucketTable.more && currentIndex < 0) { + bucketTable.collectMore(); + currentIndex = bucketTable.bitSet.nextSetBit(0); + } + } + return currentDoc = DocIdSetIterator.NO_MORE_DOCS; + } + + @Override + public int advance(int target) throws IOException { + // TODO: should use trickier method to skip documents. + if (target < bucketTable.end) { + return slowAdvance(target); + + } else { + bucketTable.advance(target); + return slowAdvance(target); + } + } + + @Override + public long cost() { + return requiredConjunctionScorer.cost(); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/ConjunctionScorer.java b/lucene/core/src/java/org/apache/lucene/search/ConjunctionScorer.java index 3e81187..e303ed8 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ConjunctionScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/ConjunctionScorer.java @@ -102,12 +102,11 @@ class ConjunctionScorer extends Scorer { @Override public float score() throws IOException { - // TODO: sum into a double and cast to float if we ever send required clauses to BS1 - float sum = 0.0f; + double sum = 0.0f; for (DocsAndFreqs docs : docsAndFreqs) { sum += docs.scorer.score(); } - return sum * coord; + return (float) (sum * coord); } @Override diff --git a/lucene/core/src/test/org/apache/lucene/search/TestBooleanScorer.java b/lucene/core/src/test/org/apache/lucene/search/TestBooleanScorer.java index 358a513..69d71af 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestBooleanScorer.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestBooleanScorer.java @@ -96,7 +96,7 @@ public class TestBooleanScorer extends LuceneTestCase { } }}; - BooleanScorer bs = new BooleanScorer(weight, false, 1, Arrays.asList(scorers), Collections.emptyList(), scorers.length); + BooleanScorer bs = new BooleanScorer(weight, false, 1, Collections.emptyList(), Arrays.asList(scorers), Collections.emptyList(), 
scorers.length); final List hits = new ArrayList<>(); bs.score(new SimpleCollector() { diff --git a/lucene/core/src/test/org/apache/lucene/search/TestBooleanUnevenly.java b/lucene/core/src/test/org/apache/lucene/search/TestBooleanUnevenly.java new file mode 100644 index 0000000..8b11da1 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestBooleanUnevenly.java @@ -0,0 +1,132 @@ +package org.apache.lucene.search; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.index.Term; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.LuceneTestCase; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +/** + * BooleanQuery.scorer should be tested, when hit documents + * are very unevenly distributed. 
+ */ +public class TestBooleanUnevenly extends LuceneTestCase { + private static IndexSearcher searcher; + private static IndexReader reader; + + public static final String field = "field"; + private static Directory directory; + + private static int count1 = 0; + + @BeforeClass + public static void beforeClass() throws Exception { + directory = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter(random(), directory, new MockAnalyzer(random())); + Document doc; + for (int i=0;i<2;i++) { + for (int j=0;j<2048;j++) { + doc = new Document(); + doc.add(newTextField(field, "1", Field.Store.NO)); + count1 ++; + w.addDocument(doc); + } + for (int j=0;j<2048;j++) { + doc = new Document(); + doc.add(newTextField(field, "2", Field.Store.NO)); + w.addDocument(doc); + } + doc = new Document(); + doc.add(newTextField(field, "1", Field.Store.NO)); + count1 ++; + w.addDocument(doc); + for (int j=0;j<2048;j++) { + doc = new Document(); + doc.add(newTextField(field, "2", Field.Store.NO)); + w.addDocument(doc); + } + } + reader = w.getReader(); + searcher = newSearcher(reader); + w.shutdown(); + } + + @AfterClass + public static void afterClass() throws Exception { + reader.close(); + directory.close(); + searcher = null; + reader = null; + directory = null; + } + + @Test + public void testQueries01() throws Exception { + BooleanQuery query = new BooleanQuery(); + query.add(new TermQuery(new Term(field, "1")), BooleanClause.Occur.MUST); + query.add(new TermQuery(new Term(field, "1")), BooleanClause.Occur.SHOULD); + query.add(new TermQuery(new Term(field, "2")), BooleanClause.Occur.SHOULD); + + TopScoreDocCollector collector = TopScoreDocCollector.create(1000, false); + searcher.search(query, null, collector); + TopDocs tops1 = collector.topDocs(); + ScoreDoc[] hits1 = tops1.scoreDocs; + int hitsNum1 = tops1.totalHits; + + collector = TopScoreDocCollector.create(1000, true); + searcher.search(query, null, collector); + TopDocs tops2 = collector.topDocs(); + ScoreDoc[] hits2 
= tops2.scoreDocs; + int hitsNum2 = tops2.totalHits; + + assertEquals(hitsNum1, count1); + assertEquals(hitsNum2, count1); + CheckHits.checkEqual(query, hits1, hits2); + } + + @Test + public void testQueries02() throws Exception { + BooleanQuery query = new BooleanQuery(); + query.add(new TermQuery(new Term(field, "1")), BooleanClause.Occur.SHOULD); + query.add(new TermQuery(new Term(field, "1")), BooleanClause.Occur.SHOULD); + + TopScoreDocCollector collector = TopScoreDocCollector.create(1000, false); + searcher.search(query, null, collector); + TopDocs tops1 = collector.topDocs(); + ScoreDoc[] hits1 = tops1.scoreDocs; + int hitsNum1 = tops1.totalHits; + + collector = TopScoreDocCollector.create(1000, true); + searcher.search(query, null, collector); + TopDocs tops2 = collector.topDocs(); + ScoreDoc[] hits2 = tops2.scoreDocs; + int hitsNum2 = tops2.totalHits; + + assertEquals(hitsNum1, count1); + assertEquals(hitsNum2, count1); + CheckHits.checkEqual(query, hits1, hits2); + } +}