Index: src/java/org/apache/lucene/search/FlagCombiningQuery.java =================================================================== --- src/java/org/apache/lucene/search/FlagCombiningQuery.java (revision 0) +++ src/java/org/apache/lucene/search/FlagCombiningQuery.java (revision 0) @@ -0,0 +1,183 @@ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; + +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Searcher; +import org.apache.lucene.search.Similarity; +import org.apache.lucene.search.Weight; + +/** + * Adjusts scores of contained query to include match flags provided by FlagRecordingQuery objects + * nested in the query + * @author Mark + * + */ +public class FlagCombiningQuery extends Query +{ + private static final long serialVersionUID = 1L; + Query childQuery; + public static int[]flagMasks={1,1<<1,1<<2,1<<3,1<<4,1<<5,1<<6,1<<7}; + + List matchFlaggers=new ArrayList(); + static final float DEFAULT_MULTIPLIER=1000f; + float multiplier=DEFAULT_MULTIPLIER; + + FlagRecordingQuery flaggersArray[]=null; + + public FlagCombiningQuery(Query childQuery) + { + super(); + this.childQuery = childQuery; + } + + public void extractTerms(Set terms) + { + childQuery.extractTerms(terms); + } + + public Query rewrite(IndexReader reader) throws IOException + { + flaggersArray=(FlagRecordingQuery[]) matchFlaggers.toArray(new FlagRecordingQuery[matchFlaggers.size()]); + childQuery=childQuery.rewrite(reader); + return this; + } + + public String toString(String field) + { + return "QueryMatchMonitor ("+childQuery.toString()+")"; + } + + public Weight createWeight(Searcher searcher) throws IOException + { + return new FlagCombiningQueryWeight(childQuery.weight(searcher)); + } + + class FlagCombiningQueryWeight extends Weight + { + Weight delegateWeight; + + public FlagCombiningQueryWeight(Weight weight) + { + super(); + delegateWeight = weight; + } + + @Override + public Query getQuery() + { + return delegateWeight.getQuery(); + } + + @Override + public float getValue() + { + return delegateWeight.getValue(); + } + + @Override + public float sumOfSquaredWeights() throws IOException + { + return delegateWeight.sumOfSquaredWeights(); + } + + @Override + public void normalize(float norm) + { + delegateWeight.normalize(norm); + } + + @Override + public Scorer scorer(IndexReader reader, boolean scoreDocsInOrder, + boolean topScorer) throws IOException + { + Scorer delegateScorer=delegateWeight.scorer(reader,scoreDocsInOrder,topScorer); + return new FlagCombiningQueryScorer(delegateScorer.getSimilarity(),delegateScorer); + } + + @Override + public Explanation explain(IndexReader reader, int doc) throws IOException + { + return delegateWeight.explain(reader, doc); + } + + } + class FlagCombiningQueryScorer extends Scorer + { + private Scorer delegateScorer; + + public FlagCombiningQueryScorer(Similarity similarity, + Scorer delegateScorer) + { + super(similarity); + this.delegateScorer=delegateScorer; + } + + @Override + public float score() throws IOException + { + float score = delegateScorer.score(); + int flags=0; + int d=docID(); + for (FlagRecordingQuery frq : flaggersArray) + { + if(frq.matched(d)) + { + int mask=flagMasks[frq.flag-1]; + flags|=mask; + } + } + + //Multiply score to turn float into int with sufficient fractions in score. + int shiftedI=(int) (score*multiplier); + //Shift int to make space for byte holding flags + int iPlusSpaceForByte=shiftedI<<8; + //Add match flags + int iCombinedScoreAndFlags=iPlusSpaceForByte|flags; + return iCombinedScoreAndFlags; + } + + + @Override + public int docID() + { + return delegateScorer.docID(); + } + + @Override + public int nextDoc() throws IOException + { + return delegateScorer.nextDoc(); + } + + @Override + public int advance(int doc) throws IOException + { + return delegateScorer.advance(doc); + } + + } + protected void add(FlagRecordingQuery flagRecordingQuery) + { + matchFlaggers.add(flagRecordingQuery); + } + + public static boolean hasFlag(int flagNum, float score) + { + return hasFlag(flagNum, score, DEFAULT_MULTIPLIER); + } + + public static boolean hasFlag(int flagNum, float score, float multiplier) + { + int iCombinedScoreAndFlags=(int) score; + int flagsInt=iCombinedScoreAndFlags&0x00ff; + return (flagsInt&flagMasks[flagNum-1])!=0; + } + +} Index: src/java/org/apache/lucene/search/FlagRecordingQuery.java =================================================================== --- src/java/org/apache/lucene/search/FlagRecordingQuery.java (revision 0) +++ src/java/org/apache/lucene/search/FlagRecordingQuery.java (revision 0) @@ -0,0 +1,170 @@ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Set; + +import org.apache.lucene.index.IndexReader; +/** + * Class used to record flag for matches on child queries. These flags are combined into bits encoded + * in the score + * + * @author Mark + * + */ +public class FlagRecordingQuery extends Query +{ + Query childQuery; + private int currDoc=-1; + private int lastDoc=-1; + private float score; + protected int flag; + + public FlagRecordingQuery(FlagCombiningQuery fcq,Query childQuery, int flag) + { + super(); + this.flag=flag; + this.childQuery = childQuery; + fcq.add(this); + } + + public boolean matched(int doc) + { +// System.out.println("match check: "+ childQuery+" on doc#"+doc+ " when at "+this.currDoc); + return (this.currDoc==doc)||(this.lastDoc==doc); + } + + + public void extractTerms(Set terms) + { + childQuery.extractTerms(terms); + } + + public Query rewrite(IndexReader reader) throws IOException + { + //TODO clone so QueryMatchMonitor is rewritable + childQuery=childQuery.rewrite(reader); + return this; + } + + public String toString(String field) + { + return "QueryMatchMonitor ("+childQuery.toString()+")"; + } + + public Weight createWeight(Searcher searcher) throws IOException + { + return new FlagRecordingQueryWeight(childQuery.weight(searcher)); + } + + class FlagRecordingQueryWeight extends Weight + { + /** + * + */ + private static final long serialVersionUID = 1L; + Weight delegateWeight; + + public FlagRecordingQueryWeight(Weight weight) + { + super(); + delegateWeight = weight; + } + + @Override + public Query getQuery() + { + return delegateWeight.getQuery(); + } + + @Override + public float getValue() + { + return delegateWeight.getValue(); + } + + @Override + public float sumOfSquaredWeights() throws IOException + { + return delegateWeight.sumOfSquaredWeights(); + } + + @Override + public void normalize(float norm) + { + delegateWeight.normalize(norm); + } + + @Override + public Scorer scorer(IndexReader reader, boolean scoreDocsInOrder, + boolean topScorer) throws IOException + { + currDoc=-1; + lastDoc=-1; + Scorer delegateScorer=delegateWeight.scorer(reader, scoreDocsInOrder,topScorer); + return new FlagRecordingQueryScorer(delegateScorer.getSimilarity(),delegateScorer); + } + + @Override + public Explanation explain(IndexReader reader, int doc) throws IOException + { + return delegateWeight.explain(reader, doc); + } + } + class FlagRecordingQueryScorer extends Scorer + { + private Scorer delegateScorer; + + public FlagRecordingQueryScorer(Similarity similarity, + Scorer delegateScorer) + { + super(similarity); + this.delegateScorer=delegateScorer; + } + + @Override + public Explanation explain(int doc) throws IOException + { + return delegateScorer.explain(doc); + } + + @Override + public float score() throws IOException + { +// System.out.println("get score doc#"+doc()+" on q="+childQuery); + FlagRecordingQuery.this.score= delegateScorer.score(); + return score; + } + + @Override + public int docID() + { + return currDoc; + } + + @Override + public int nextDoc() throws IOException + { + int nextDoc= delegateScorer.nextDoc(); + if(nextDoc!=NO_MORE_DOCS) + { + lastDoc=currDoc; + currDoc=nextDoc; + } + return nextDoc; + } + + @Override + public int advance(int doc) throws IOException + { + int nextDoc= delegateScorer.advance(doc); + if(nextDoc!=NO_MORE_DOCS) + { + lastDoc=currDoc; + currDoc=nextDoc; + } + return nextDoc; + } + + } + +} Index: src/test/org/apache/lucene/search/TestFlaggingQuery.java =================================================================== --- src/test/org/apache/lucene/search/TestFlaggingQuery.java (revision 0) +++ src/test/org/apache/lucene/search/TestFlaggingQuery.java (revision 0) @@ -0,0 +1,172 @@ +package org.apache.lucene.search; +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import java.io.IOException; +import java.util.HashMap; + +import org.apache.lucene.analysis.WhitespaceAnalyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.NumericField; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.IndexWriter.MaxFieldLength; +import org.apache.lucene.search.BooleanClause.Occur; +import org.apache.lucene.store.RAMDirectory; +import org.apache.lucene.util.LuceneTestCase; + +/** + * Demonstrates ability to determine matches on any type of query clauses for highlighting + * purposes. This demo include NumericRangeQueries and cached Term filters + * + * + * @author Mark + * + */ +public class TestFlaggingQuery extends LuceneTestCase +{ + + + + public TestFlaggingQuery(String name) + { + super(name); + // TODO Auto-generated constructor stub + } + + HashMap cachedCommonFilters=new HashMap(); + + @Override + protected void setUp() throws Exception + { + cacheCommonTermFilter("transmission","manual"); + cacheCommonTermFilter("transmission","automatic"); + super.setUp(); + } + + + private void cacheCommonTermFilter(String field, String value) + { + Term term=new Term(field,value); + TermsFilter tf = new TermsFilter(); + tf.addTerm(term); + CachingWrapperFilter cachedFilter=new CachingWrapperFilter(tf); + cachedCommonFilters.put(term, cachedFilter); + } + + + public void testQuery() throws Exception + { + RAMDirectory rd=new RAMDirectory(); + IndexWriter w=new IndexWriter(rd,new WhitespaceAnalyzer(),true,MaxFieldLength.UNLIMITED); + + + + addCarDoc(w,new Car(2000, 50000, 4, "automatic","ford", "focus")); + addCarDoc(w,new Car(1000, 80000, 4, "manual","volkswagen", "golf")); + addCarDoc(w,new Car(19000, 32000, 4, "manual","porsche", "boxster")); + w.commit(); + w.close(); + + IndexSearcher s=new IndexSearcher(rd,true); + + BooleanQuery bq=new BooleanQuery(); + + int queryStartPrice=2000; + int queryEndPrice=10000; + String queryTransmissionType="automatic"; + int queryStartMileage=20000; + int queryEndMileage=90000; + + Query priceQuery = NumericRangeQuery.newIntRange("price", 4, queryStartPrice,queryEndPrice, true, true); + Query transmissionQuery=new ConstantScoreQuery(cachedCommonFilters.get(new Term("transmission",queryTransmissionType))); + Query mileageQuery = NumericRangeQuery.newIntRange("mileage", 4, queryStartMileage,queryEndMileage, true, true); + + FlagCombiningQuery fcq=new FlagCombiningQuery(bq); + int FLAG_PRICE_MATCH=1; + int FLAG_TRANSMISSION_MATCH=2; + int FLAG_MILEAGE_MATCH=3; + + FlagRecordingQuery frqPrice=new FlagRecordingQuery(fcq,priceQuery,FLAG_PRICE_MATCH); + FlagRecordingQuery frqTransmission=new FlagRecordingQuery(fcq,transmissionQuery,FLAG_TRANSMISSION_MATCH); + FlagRecordingQuery frqMileage=new FlagRecordingQuery(fcq,mileageQuery,FLAG_MILEAGE_MATCH); + bq.add(new BooleanClause(frqPrice,Occur.SHOULD)); + bq.add(new BooleanClause(frqTransmission,Occur.SHOULD)); + bq.add(new BooleanClause(frqMileage,Occur.SHOULD)); + + TopDocs td = s.search(fcq,10); + ScoreDoc[] sd = td.scoreDocs; + IndexReader reader=s.getIndexReader(); + for (ScoreDoc scoreDoc : sd) + { + float score=scoreDoc.score; + Document doc=reader.document(scoreDoc.doc); + int price=Integer.parseInt(doc.get("price")); + int mileage=Integer.parseInt(doc.get("mileage")); + String transmission=doc.get("transmission"); + + assertEquals("Price "+price+" ["+queryStartPrice+"-"+queryEndPrice+"]="+ + FlagCombiningQuery.hasFlag(FLAG_PRICE_MATCH, score), + ((price>=queryStartPrice)&&(price<=queryEndPrice)), + FlagCombiningQuery.hasFlag(FLAG_PRICE_MATCH, score)); + assertEquals("Mileage matched", ((mileage>=queryStartMileage)&&(price<=queryEndMileage)), + FlagCombiningQuery.hasFlag(FLAG_MILEAGE_MATCH, score)); + assertEquals("Mileage matched", transmission.equals(queryTransmissionType), + FlagCombiningQuery.hasFlag(FLAG_TRANSMISSION_MATCH, score)); + } + + s.close(); + + } + static class Car + { + int price; //possibly large range values + int mileage; //possibly large range values + int numDoors; //Small number of values + String transmission; //only one of two values - best cached + String make; + String model; + public Car(int price, int mileage, int numDoors, String transmission, + String make, String model) + { + super(); + this.price = price; + this.mileage = mileage; + this.numDoors = numDoors; + this.transmission = transmission; + this.make = make; + this.model = model; + } + + } + + private void addCarDoc(IndexWriter w, Car car) throws CorruptIndexException, IOException + { + Document doc=new Document(); + doc.add(new NumericField("price", 4,Field.Store.YES, true).setIntValue(car.price)); + doc.add(new NumericField("mileage", 4,Field.Store.YES, true).setIntValue(car.mileage)); + doc.add(new Field("transmission",car.transmission,Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("make",car.make,Field.Store.YES, Field.Index.ANALYZED)); + doc.add(new Field("model",car.model,Field.Store.YES, Field.Index.ANALYZED)); + w.addDocument(doc); + } + + + +}