diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/NearestFuzzyQuery.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/NearestFuzzyQuery.java index d286eb9351..c9db3507f4 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/utils/NearestFuzzyQuery.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/NearestFuzzyQuery.java @@ -39,6 +39,7 @@ import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.FuzzyTermsEnum; import org.apache.lucene.search.MaxNonCompetitiveBoostAttribute; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.TermQuery; import org.apache.lucene.util.AttributeSource; import org.apache.lucene.util.BytesRef; @@ -269,6 +270,13 @@ public class NearestFuzzyQuery extends Query { return bq.build(); } + @Override + public void visit(QueryVisitor visitor) { + for (FieldVals fv : fieldVals) { + visitor.visitLeaf(this, fv.fieldName, () -> t -> true); + } + } + //Holds info for a fuzzy term variant - initially score is set to edit distance (for ranking best // term variants) then is reset with IDF for use in ranking against all other // terms/fields diff --git a/lucene/core/src/java/org/apache/lucene/search/AutomatonQuery.java b/lucene/core/src/java/org/apache/lucene/search/AutomatonQuery.java index 7fb155d78e..1f32f4b5e4 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AutomatonQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AutomatonQuery.java @@ -151,7 +151,12 @@ public class AutomatonQuery extends MultiTermQuery { buffer.append("}"); return buffer.toString(); } - + + @Override + public void visit(QueryVisitor visitor) { + visitor.visitLeaf(this, field, () -> QueryVisitor.matchesAutomaton(compiled)); + } + /** Returns the automaton used to create this query */ public Automaton getAutomaton() { return automaton; diff --git a/lucene/core/src/java/org/apache/lucene/search/BlendedTermQuery.java b/lucene/core/src/java/org/apache/lucene/search/BlendedTermQuery.java index 8f85e25b44..a05e6dae0d 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BlendedTermQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/BlendedTermQuery.java @@ -294,6 +294,14 @@ public final class BlendedTermQuery extends Query { return rewriteMethod.rewrite(termQueries); } + @Override + public void visit(QueryVisitor visitor) { + QueryVisitor v = visitor.getShouldMatchVisitor(this); + for (Term term : terms) { + v.visitLeaf(this, term); + } + } + private static TermStates adjustFrequencies(IndexReaderContext readerContext, TermStates ctx, int artificialDf, long artificialTtf) throws IOException { List leaves = readerContext.leaves(); diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java b/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java index f52df9fb9c..600fa43af3 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java @@ -451,6 +451,25 @@ public class BooleanQuery extends Query implements Iterable { return super.rewrite(reader); } + @Override + public void visit(QueryVisitor visitor) { + for (BooleanClause clause : clauses) { + switch (clause.getOccur()) { + case MUST: + clause.getQuery().visit(visitor.getMatchingVisitor(this)); + break; + case SHOULD: + clause.getQuery().visit(visitor.getShouldMatchVisitor(this)); + break; + case FILTER: + clause.getQuery().visit(visitor.getFilteringVisitor(this)); + break; + case MUST_NOT: + clause.getQuery().visit(visitor.getNonMatchingVisitor(this)); + } + } + } + /** Prints a user-readable version of this query. */ @Override public String toString(String field) { diff --git a/lucene/core/src/java/org/apache/lucene/search/BoostQuery.java b/lucene/core/src/java/org/apache/lucene/search/BoostQuery.java index 4e4649cb71..7d9eb536e7 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BoostQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/BoostQuery.java @@ -104,6 +104,11 @@ public final class BoostQuery extends Query { return super.rewrite(reader); } + @Override + public void visit(QueryVisitor visitor) { + query.visit(visitor.getMatchingVisitor(this)); + } + @Override public String toString(String field) { StringBuilder builder = new StringBuilder(); diff --git a/lucene/core/src/java/org/apache/lucene/search/ConstantScoreQuery.java b/lucene/core/src/java/org/apache/lucene/search/ConstantScoreQuery.java index 5c9ed19f89..e8c0a50df3 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ConstantScoreQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/ConstantScoreQuery.java @@ -64,6 +64,11 @@ public final class ConstantScoreQuery extends Query { return super.rewrite(reader); } + @Override + public void visit(QueryVisitor visitor) { + query.visit(visitor.getFilteringVisitor(this)); + } + /** We return this as our {@link BulkScorer} so that if the CSQ * wraps a query with its own optimized top-level * scorer (e.g. BooleanScorer) we can use that diff --git a/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java b/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java index 43b42b575c..68da293860 100644 --- a/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/DisjunctionMaxQuery.java @@ -237,6 +237,14 @@ public final class DisjunctionMaxQuery extends Query implements Iterable return super.rewrite(reader); } + @Override + public void visit(QueryVisitor visitor) { + QueryVisitor v = visitor.getShouldMatchVisitor(this); + for (Query q : disjuncts) { + q.visit(v); + } + } + /** Prettyprint us. * @param field the field to which we are applied * @return a string that shows what we do, of the form "(disjunct1 | disjunct2 | ... | disjunctn)^boost" diff --git a/lucene/core/src/java/org/apache/lucene/search/FuzzyQuery.java b/lucene/core/src/java/org/apache/lucene/search/FuzzyQuery.java index 3c1eacd80e..1df8014469 100644 --- a/lucene/core/src/java/org/apache/lucene/search/FuzzyQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/FuzzyQuery.java @@ -173,7 +173,12 @@ public class FuzzyQuery extends MultiTermQuery { buffer.append(Integer.toString(maxEdits)); return buffer.toString(); } - + + @Override + public void visit(QueryVisitor visitor) { + visitor.visitLeaf(this, field, () -> term -> true); // TODO construct automaton from FuzzyTermsEnum + } + @Override public int hashCode() { final int prime = 31; diff --git a/lucene/core/src/java/org/apache/lucene/search/IndexOrDocValuesQuery.java b/lucene/core/src/java/org/apache/lucene/search/IndexOrDocValuesQuery.java index d69421ec4d..5c5628f38c 100644 --- a/lucene/core/src/java/org/apache/lucene/search/IndexOrDocValuesQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/IndexOrDocValuesQuery.java @@ -109,6 +109,13 @@ public final class IndexOrDocValuesQuery extends Query { return this; } + @Override + public void visit(QueryVisitor visitor) { + QueryVisitor v = visitor.getMatchingVisitor(this); + indexQuery.visit(v); + dvQuery.visit(v); + } + @Override public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { final Weight indexWeight = indexQuery.createWeight(searcher, scoreMode, boost); diff --git a/lucene/core/src/java/org/apache/lucene/search/MultiPhraseQuery.java b/lucene/core/src/java/org/apache/lucene/search/MultiPhraseQuery.java index c8d22baec9..1599a233bd 100644 --- a/lucene/core/src/java/org/apache/lucene/search/MultiPhraseQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/MultiPhraseQuery.java @@ -203,6 +203,17 @@ public class MultiPhraseQuery extends Query { } } + @Override + public void visit(QueryVisitor visitor) { + QueryVisitor v = visitor.getMatchingVisitor(this); + for (Term[] terms : termArrays) { + QueryVisitor sv = v.getShouldMatchVisitor(this); + for (Term term : terms) { + sv.visitLeaf(this, term); + } + } + } + @Override public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { final Map termStates = new HashMap<>(); diff --git a/lucene/core/src/java/org/apache/lucene/search/MultiTermQueryConstantScoreWrapper.java b/lucene/core/src/java/org/apache/lucene/search/MultiTermQueryConstantScoreWrapper.java index 9c3721117a..b0043b634e 100644 --- a/lucene/core/src/java/org/apache/lucene/search/MultiTermQueryConstantScoreWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/search/MultiTermQueryConstantScoreWrapper.java @@ -231,4 +231,9 @@ final class MultiTermQueryConstantScoreWrapper extends }; } + + @Override + public void visit(QueryVisitor visitor) { + query.visit(visitor.getFilteringVisitor(this)); + } } diff --git a/lucene/core/src/java/org/apache/lucene/search/NGramPhraseQuery.java b/lucene/core/src/java/org/apache/lucene/search/NGramPhraseQuery.java index db997d37e7..5084d4ea24 100644 --- a/lucene/core/src/java/org/apache/lucene/search/NGramPhraseQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/NGramPhraseQuery.java @@ -77,6 +77,11 @@ public class NGramPhraseQuery extends Query { return builder.build(); } + @Override + public void visit(QueryVisitor visitor) { + phraseQuery.visit(visitor.getFilteringVisitor(this)); + } + @Override public boolean equals(Object other) { return sameClassAs(other) && diff --git a/lucene/core/src/java/org/apache/lucene/search/PhraseQuery.java b/lucene/core/src/java/org/apache/lucene/search/PhraseQuery.java index 8f042716a6..bcbe6db02c 100644 --- a/lucene/core/src/java/org/apache/lucene/search/PhraseQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/PhraseQuery.java @@ -284,6 +284,14 @@ public class PhraseQuery extends Query { } } + @Override + public void visit(QueryVisitor visitor) { + QueryVisitor v = visitor.getMatchingVisitor(this); + for (Term term : terms) { + v.visitLeaf(this, term); + } + } + static class PostingsAndFreq implements Comparable { final PostingsEnum postings; final int position; diff --git a/lucene/core/src/java/org/apache/lucene/search/Query.java b/lucene/core/src/java/org/apache/lucene/search/Query.java index 54de63fc02..4767aeb9e0 100644 --- a/lucene/core/src/java/org/apache/lucene/search/Query.java +++ b/lucene/core/src/java/org/apache/lucene/search/Query.java @@ -74,6 +74,14 @@ public abstract class Query { return this; } + /** + * Recurse through the query tree, visiting any child queries + * @param visitor a QueryVisitor to be called by each query in the tree + */ + public void visit(QueryVisitor visitor) { + visitor.visitLeaf(this); + } + /** * Override and implement query instance equivalence properly in a subclass. * This is required so that {@link QueryCache} works properly. diff --git a/lucene/core/src/java/org/apache/lucene/search/QueryVisitor.java b/lucene/core/src/java/org/apache/lucene/search/QueryVisitor.java new file mode 100644 index 0000000000..c6a09253ba --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/QueryVisitor.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search; + +import java.util.function.Predicate; +import java.util.function.Supplier; + +import org.apache.lucene.index.Term; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.automaton.CompiledAutomaton; + +/** + * Interface to allow recursion through a query tree + * + * @see Query#visit(QueryVisitor) + */ +public interface QueryVisitor { + + /** + * Called by leaf queries that match on a specific term + * @param query the leaf query visited + * @param term the term the query will match on + */ + void visitLeaf(Query query, Term term); + + /** + * Called by leaf queries that do not match against the terms index + * @param query the leaf query visited + */ + default void visitLeaf(Query query) {} + + /** + * Called by leaf queries that match against a set of terms defined by a predicate + * @param query the leaf query + * @param field the field the query matches against + * @param predicateSupplier a supplier for a predicate that will select matching terms + */ + default void visitLeaf(Query query, String field, Supplier> predicateSupplier) {} + + /** + * Pulls a visitor instance for visiting matching child clauses of a query + * + * The default implementation returns {@code this} + * + * @param parent the query visited + */ + default QueryVisitor getMatchingVisitor(Query parent) { + return this; + } + + /** + * Pulls a visitor instance for visiting matching 'should-match' child clauses of a query + * + * The default implementation returns {@code this} + * + * @param parent the query visited + */ + default QueryVisitor getShouldMatchVisitor(Query parent) { + return this; + } + + /** + * Pulls a visitor instance for visiting matching non-scoring child clauses of a query + * + * The default implementation returns {@code this} + * + * @param parent the query visited + */ + default QueryVisitor getFilteringVisitor(Query parent) { + return this; + } + + /** + * Pulls a visitor instance for visiting matching 'must-not' child clauses of a query + * + * The default implementation returns {@link #NO_OP} + * + * @param parent the query visited + */ + default QueryVisitor getNonMatchingVisitor(Query parent) { + return NO_OP; + } + + /** + * Builds a predicate for matching a set of bytes from a {@link CompiledAutomaton} + */ + static Predicate matchesAutomaton(CompiledAutomaton automaton) { + return term -> { + switch (automaton.type) { + case NONE: + return false; + case ALL: + return true; + case SINGLE: + assert automaton.term != null; + return automaton.term.equals(term); + } + assert automaton.runAutomaton != null; + return automaton.runAutomaton.run(term.bytes, term.offset, term.length); + }; + } + + /** + * A QueryVisitor implementation that collects no terms + */ + QueryVisitor NO_OP = (query, term) -> { }; + +} diff --git a/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java b/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java index 25205adb76..7a551ee82f 100644 --- a/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java @@ -117,6 +117,14 @@ public final class SynonymQuery extends Query { return this; } + @Override + public void visit(QueryVisitor visitor) { + QueryVisitor v = visitor.getShouldMatchVisitor(this); + for (Term term : terms) { + v.visitLeaf(this, term); + } + } + @Override public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { if (scoreMode.needsScores()) { diff --git a/lucene/core/src/java/org/apache/lucene/search/TermInSetQuery.java b/lucene/core/src/java/org/apache/lucene/search/TermInSetQuery.java index 814bea2638..88cae717d2 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TermInSetQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/TermInSetQuery.java @@ -42,6 +42,7 @@ import org.apache.lucene.util.Accountable; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRefBuilder; +import org.apache.lucene.util.BytesRefHash; import org.apache.lucene.util.DocIdSetBuilder; import org.apache.lucene.util.RamUsageEstimator; @@ -122,6 +123,18 @@ public class TermInSetQuery extends Query implements Accountable { return super.rewrite(reader); } + @Override + public void visit(QueryVisitor visitor) { + visitor.visitLeaf(this, field, () -> { + BytesRefHash terms = new BytesRefHash(); + TermIterator it = termData.iterator(); + for (BytesRef term = it.next(); term != null; term = it.next()) { + terms.add(term); + } + return t -> terms.find(t) != -1; + }); + } + @Override public boolean equals(Object other) { return sameClassAs(other) && diff --git a/lucene/core/src/java/org/apache/lucene/search/TermQuery.java b/lucene/core/src/java/org/apache/lucene/search/TermQuery.java index 1eebbce6e9..57cee99d44 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TermQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/TermQuery.java @@ -205,6 +205,11 @@ public class TermQuery extends Query { return new TermWeight(searcher, scoreMode, boost, termState); } + @Override + public void visit(QueryVisitor visitor) { + visitor.visitLeaf(this, term); + } + /** Prints a user-readable version of this query. */ @Override public String toString(String field) { diff --git a/lucene/core/src/java/org/apache/lucene/search/spans/FieldMaskingSpanQuery.java b/lucene/core/src/java/org/apache/lucene/search/spans/FieldMaskingSpanQuery.java index 4a4c4fbae9..a17e0ef304 100644 --- a/lucene/core/src/java/org/apache/lucene/search/spans/FieldMaskingSpanQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/spans/FieldMaskingSpanQuery.java @@ -17,14 +17,15 @@ package org.apache.lucene.search.spans; +import java.io.IOException; +import java.util.Objects; + import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; -import java.io.IOException; -import java.util.Objects; - /** *

Wrapper to allow {@link SpanQuery} objects participate in composite * single-field SpanQueries by 'lying' about their search field. That is, @@ -104,6 +105,11 @@ public final class FieldMaskingSpanQuery extends SpanQuery { return super.rewrite(reader); } + @Override + public void visit(QueryVisitor visitor) { + maskedQuery.visit(visitor); + } + @Override public String toString(String field) { StringBuilder buffer = new StringBuilder(); diff --git a/lucene/core/src/java/org/apache/lucene/search/spans/SpanBoostQuery.java b/lucene/core/src/java/org/apache/lucene/search/spans/SpanBoostQuery.java index 9556959a3e..78918a5943 100644 --- a/lucene/core/src/java/org/apache/lucene/search/spans/SpanBoostQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/spans/SpanBoostQuery.java @@ -24,6 +24,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; /** @@ -93,6 +94,11 @@ public final class SpanBoostQuery extends SpanQuery { return super.rewrite(reader); } + @Override + public void visit(QueryVisitor visitor) { + query.visit(visitor.getMatchingVisitor(this)); + } + @Override public String toString(String field) { StringBuilder builder = new StringBuilder(); diff --git a/lucene/core/src/java/org/apache/lucene/search/spans/SpanContainQuery.java b/lucene/core/src/java/org/apache/lucene/search/spans/SpanContainQuery.java index 23c1e2b829..ccd56f9cff 100644 --- a/lucene/core/src/java/org/apache/lucene/search/spans/SpanContainQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/spans/SpanContainQuery.java @@ -17,18 +17,19 @@ package org.apache.lucene.search.spans; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; import org.apache.lucene.index.TermStates; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Map; -import java.util.Objects; -import java.util.Set; +import org.apache.lucene.search.QueryVisitor; abstract class SpanContainQuery extends SpanQuery implements Cloneable { @@ -128,6 +129,13 @@ abstract class SpanContainQuery extends SpanQuery implements Cloneable { return super.rewrite(reader); } + @Override + public void visit(QueryVisitor visitor) { + QueryVisitor v = visitor.getMatchingVisitor(this); + big.visit(v); + little.visit(v); + } + @Override public boolean equals(Object other) { return sameClassAs(other) && diff --git a/lucene/core/src/java/org/apache/lucene/search/spans/SpanMultiTermQueryWrapper.java b/lucene/core/src/java/org/apache/lucene/search/spans/SpanMultiTermQueryWrapper.java index 088e73092d..e8f0ec8a52 100644 --- a/lucene/core/src/java/org/apache/lucene/search/spans/SpanMultiTermQueryWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/search/spans/SpanMultiTermQueryWrapper.java @@ -29,6 +29,7 @@ import org.apache.lucene.search.BooleanClause.Occur; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MultiTermQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.ScoringRewrite; import org.apache.lucene.search.TopTermsRewrite; @@ -121,7 +122,12 @@ public class SpanMultiTermQueryWrapper extends SpanQue public Query rewrite(IndexReader reader) throws IOException { return rewriteMethod.rewrite(reader, query); } - + + @Override + public void visit(QueryVisitor visitor) { + query.visit(visitor.getMatchingVisitor(this)); + } + @Override public int hashCode() { return classHash() * 31 + query.hashCode(); diff --git a/lucene/core/src/java/org/apache/lucene/search/spans/SpanNearQuery.java b/lucene/core/src/java/org/apache/lucene/search/spans/SpanNearQuery.java index 17b9e51513..25645d8328 100644 --- a/lucene/core/src/java/org/apache/lucene/search/spans/SpanNearQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/spans/SpanNearQuery.java @@ -33,6 +33,7 @@ import org.apache.lucene.index.TermStates; import org.apache.lucene.index.Terms; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Weight; @@ -265,6 +266,14 @@ public class SpanNearQuery extends SpanQuery implements Cloneable { return super.rewrite(reader); } + @Override + public void visit(QueryVisitor visitor) { + QueryVisitor v = visitor.getMatchingVisitor(this); + for (SpanQuery clause : clauses) { + clause.visit(v); + } + } + @Override public boolean equals(Object other) { return sameClassAs(other) && diff --git a/lucene/core/src/java/org/apache/lucene/search/spans/SpanNotQuery.java b/lucene/core/src/java/org/apache/lucene/search/spans/SpanNotQuery.java index 6c56df3abe..d16242698e 100644 --- a/lucene/core/src/java/org/apache/lucene/search/spans/SpanNotQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/spans/SpanNotQuery.java @@ -29,6 +29,7 @@ import org.apache.lucene.index.TermStates; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TwoPhaseIterator; @@ -209,7 +210,14 @@ public final class SpanNotQuery extends SpanQuery { } return super.rewrite(reader); } - /** Returns true iff o is equal to this. */ + + @Override + public void visit(QueryVisitor visitor) { + include.visit(visitor.getMatchingVisitor(this)); + exclude.visit(visitor.getNonMatchingVisitor(this)); + } + + /** Returns true iff o is equal to this. */ @Override public boolean equals(Object other) { return sameClassAs(other) && diff --git a/lucene/core/src/java/org/apache/lucene/search/spans/SpanOrQuery.java b/lucene/core/src/java/org/apache/lucene/search/spans/SpanOrQuery.java index 849edaa30e..facf6594e9 100644 --- a/lucene/core/src/java/org/apache/lucene/search/spans/SpanOrQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/spans/SpanOrQuery.java @@ -33,6 +33,7 @@ import org.apache.lucene.search.DisiWrapper; import org.apache.lucene.search.DisjunctionDISIApproximation; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TwoPhaseIterator; import org.apache.lucene.search.Weight; @@ -88,6 +89,14 @@ public final class SpanOrQuery extends SpanQuery { return super.rewrite(reader); } + @Override + public void visit(QueryVisitor visitor) { + QueryVisitor v = visitor.getShouldMatchVisitor(this); + for (SpanQuery q : clauses) { + q.visit(v); + } + } + @Override public String toString(String field) { StringBuilder buffer = new StringBuilder(); diff --git a/lucene/core/src/java/org/apache/lucene/search/spans/SpanPositionCheckQuery.java b/lucene/core/src/java/org/apache/lucene/search/spans/SpanPositionCheckQuery.java index 099b627e1e..53aba8a2f1 100644 --- a/lucene/core/src/java/org/apache/lucene/search/spans/SpanPositionCheckQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/spans/SpanPositionCheckQuery.java @@ -28,6 +28,7 @@ import org.apache.lucene.index.Term; import org.apache.lucene.index.TermStates; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.spans.FilterSpans.AcceptStatus; @@ -126,6 +127,11 @@ public abstract class SpanPositionCheckQuery extends SpanQuery implements Clonea return super.rewrite(reader); } + @Override + public void visit(QueryVisitor visitor) { + match.visit(visitor); + } + /** Returns true iff other is equal to this. */ @Override public boolean equals(Object other) { diff --git a/lucene/core/src/java/org/apache/lucene/search/spans/SpanTermQuery.java b/lucene/core/src/java/org/apache/lucene/search/spans/SpanTermQuery.java index 42e73f2be2..d48bf057f7 100644 --- a/lucene/core/src/java/org/apache/lucene/search/spans/SpanTermQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/spans/SpanTermQuery.java @@ -33,6 +33,7 @@ import org.apache.lucene.index.TermState; import org.apache.lucene.index.Terms; import org.apache.lucene.index.TermsEnum; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; /** Matches spans containing a term. @@ -84,6 +85,11 @@ public class SpanTermQuery extends SpanQuery { return new SpanTermWeight(context, searcher, scoreMode.needsScores() ? Collections.singletonMap(term, context) : null, boost); } + @Override + public void visit(QueryVisitor visitor) { + visitor.visitLeaf(this, term); + } + public class SpanTermWeight extends SpanWeight { final TermStates termStates; diff --git a/lucene/core/src/test/org/apache/lucene/search/TestQueryVisitor.java b/lucene/core/src/test/org/apache/lucene/search/TestQueryVisitor.java new file mode 100644 index 0000000000..89b1428093 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestQueryVisitor.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Predicate; +import java.util.function.Supplier; + +import org.apache.lucene.index.Term; +import org.apache.lucene.search.spans.SpanNearQuery; +import org.apache.lucene.search.spans.SpanQuery; +import org.apache.lucene.search.spans.SpanTermQuery; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.LuceneTestCase; + +import static org.hamcrest.CoreMatchers.equalTo; + +public class TestQueryVisitor extends LuceneTestCase { + + private static final Query query = new BooleanQuery.Builder() + .add(new TermQuery(new Term("field1", "term1")), BooleanClause.Occur.MUST) + .add(new BooleanQuery.Builder() + .add(new TermQuery(new Term("field1", "term2")), BooleanClause.Occur.SHOULD) + .add(new BoostQuery(new TermQuery(new Term("field1", "term3")), 2), BooleanClause.Occur.SHOULD) + .build(), BooleanClause.Occur.MUST) + .add(new BoostQuery(new PhraseQuery.Builder() + .add(new Term("field1", "term4")) + .add(new Term("field1", "term5")) + .build(), 3), BooleanClause.Occur.MUST) + .add(new SpanNearQuery(new SpanQuery[]{ + new SpanTermQuery(new Term("field1", "term6")), + new SpanTermQuery(new Term("field1", "term7")) + }, 2, true), BooleanClause.Occur.MUST) + .add(new TermQuery(new Term("field1", "term8")), BooleanClause.Occur.MUST_NOT) + .add(new PrefixQuery(new Term("field1", "term9")), BooleanClause.Occur.SHOULD) + .build(); + + public void testExtractTermsEquivalent() { + Set terms = new HashSet<>(); + Set expected = new HashSet<>(Arrays.asList( + new Term("field1", "term1"), new Term("field1", "term2"), + new Term("field1", "term3"), new Term("field1", "term4"), + new Term("field1", "term5"), new Term("field1", "term6"), + new Term("field1", "term7") + )); + query.visit((q, t) -> terms.add(t)); + assertThat(terms, equalTo(expected)); + } + + public void extractAllTerms() { + Set terms = new HashSet<>(); + QueryVisitor visitor = new QueryVisitor() { + @Override + public void visitLeaf(Query query, Term term) { + terms.add(term); + } + @Override + public QueryVisitor getNonMatchingVisitor(Query parent) { + return this; + } + }; + Set expected = new HashSet<>(Arrays.asList( + new Term("field1", "term1"), new Term("field1", "term2"), + new Term("field1", "term3"), new Term("field1", "term4"), + new Term("field1", "term5"), new Term("field1", "term6"), + new Term("field1", "term7"), new Term("field1", "term8") + )); + query.visit(visitor); + assertThat(terms, equalTo(expected)); + } + + static class TermSelector implements QueryVisitor { + + Set terms = new HashSet<>(); + Map>>> suppliers = new HashMap<>(); + + @Override + public void visitLeaf(Query query, Term term) { + terms.add(term); + } + + @Override + public void visitLeaf(Query query, String field, Supplier> predicateSupplier) { + List>> l = suppliers.computeIfAbsent(field, f -> new ArrayList<>()); + l.add(predicateSupplier); + } + + @Override + public QueryVisitor getNonMatchingVisitor(Query parent) { + return this; // collect non-matching terms too + } + + boolean matches(Term term) { + if (terms.contains(term)) { + return true; + } + if (suppliers.containsKey(term.field())) { + for (Supplier> supplier : suppliers.get(term.field())) { + if (supplier.get().test(term.bytes())) { + return true; + } + } + } + return false; + } + } + + public void testSelectAllTerms() { + TermSelector ts = new TermSelector(); + query.visit(ts); + assertTrue(ts.matches(new Term("field1", "term1"))); + assertTrue(ts.matches(new Term("field1", "term99"))); // prefix automaton + assertFalse(ts.matches(new Term("field1", "term10"))); // must_not clause + assertFalse(ts.matches(new Term("field2", "term99"))); + } + + static class BoostedTermExtractor implements QueryVisitor { + + final float boost; + final Map termsToBoosts; + + BoostedTermExtractor(float boost, Map termsToBoosts) { + this.boost = boost; + this.termsToBoosts = termsToBoosts; + } + + @Override + public void visitLeaf(Query query, Term term) { + termsToBoosts.put(term, boost); + } + + @Override + public QueryVisitor getMatchingVisitor(Query parent) { + if (parent instanceof BoostQuery) { + return new BoostedTermExtractor(boost * ((BoostQuery)parent).getBoost(), termsToBoosts); + } + return this; + } + } + + public void testExtractTermsAndBoosts() { + Map termsToBoosts = new HashMap<>(); + query.visit(new BoostedTermExtractor(1, termsToBoosts)); + Map expected = new HashMap<>(); + expected.put(new Term("field1", "term1"), 1f); + expected.put(new Term("field1", "term2"), 1f); + expected.put(new Term("field1", "term3"), 2f); + expected.put(new Term("field1", "term4"), 3f); + expected.put(new Term("field1", "term5"), 3f); + expected.put(new Term("field1", "term6"), 1f); + expected.put(new Term("field1", "term7"), 1f); + assertThat(termsToBoosts, equalTo(expected)); + } + + static class MinimumMatchingTermSetExtractor implements QueryVisitor { + + List mustMatchLeaves = new ArrayList<>(); + List shouldMatchLeaves = new ArrayList<>(); + Term term; + int weight; + + @Override + public void visitLeaf(Query query, Term term) { + this.term = term; + this.weight = term.text().length(); + } + + @Override + public void visitLeaf(Query query, String field, Supplier> predicateSupplier) { + this.term = new Term(field, "ANY"); + this.weight = 100; + } + + @Override + public QueryVisitor getMatchingVisitor(Query parent) { + MinimumMatchingTermSetExtractor child = new MinimumMatchingTermSetExtractor(); + mustMatchLeaves.add(child); + return child; + } + + @Override + public QueryVisitor getFilteringVisitor(Query parent) { + return getMatchingVisitor(parent); + } + + @Override + public QueryVisitor getShouldMatchVisitor(Query parent) { + MinimumMatchingTermSetExtractor child = new MinimumMatchingTermSetExtractor(); + shouldMatchLeaves.add(child); + return child; + } + + int getWeight() { + if (mustMatchLeaves.size() > 0) { + mustMatchLeaves.sort(Comparator.comparingInt(MinimumMatchingTermSetExtractor::getWeight)); + return mustMatchLeaves.get(0).getWeight(); + } + if (shouldMatchLeaves.size() > 0) { + shouldMatchLeaves.sort(Comparator.comparingInt(MinimumMatchingTermSetExtractor::getWeight).reversed()); + return shouldMatchLeaves.get(0).getWeight(); + } + return weight; + } + + void getMatchesTermSet(Set terms) { + if (mustMatchLeaves.size() > 0) { + mustMatchLeaves.sort(Comparator.comparingInt(MinimumMatchingTermSetExtractor::getWeight)); + mustMatchLeaves.get(0).getMatchesTermSet(terms); + return; + } + if (shouldMatchLeaves.size() > 0) { + for (MinimumMatchingTermSetExtractor child : shouldMatchLeaves) { + child.getMatchesTermSet(terms); + } + return; + } + terms.add(term); + } + + boolean nextMatchingSet() { + if (mustMatchLeaves.size() > 0) { + if (mustMatchLeaves.get(0).nextMatchingSet()) { + return true; + } + mustMatchLeaves.remove(0); + return shouldMatchLeaves.size() > 0; + } + if (shouldMatchLeaves.size() == 0) { + return false; + } + boolean advanced = false; + for (MinimumMatchingTermSetExtractor child : shouldMatchLeaves) { + advanced |= child.nextMatchingSet(); + } + return advanced; + } + } + + public void testExtractMatchingTermSet() { + MinimumMatchingTermSetExtractor extractor = new MinimumMatchingTermSetExtractor(); + query.visit(extractor); + Set minimumTermSet = new HashSet<>(); + extractor.getMatchesTermSet(minimumTermSet); + + Set expected1 = new HashSet<>(Collections.singletonList(new Term("field1", "term1"))); + assertThat(minimumTermSet, equalTo(expected1)); + assertTrue(extractor.nextMatchingSet()); + Set expected2 = new HashSet<>(Arrays.asList(new Term("field1", "term2"), new Term("field1", "term3"))); + minimumTermSet.clear(); + extractor.getMatchesTermSet(minimumTermSet); + assertThat(minimumTermSet, equalTo(expected2)); + } + +} diff --git a/lucene/queries/src/java/org/apache/lucene/queries/CommonTermsQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/CommonTermsQuery.java index 10c232ed45..fe378f5d8a 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/CommonTermsQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/CommonTermsQuery.java @@ -33,6 +33,7 @@ import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.TermQuery; /** @@ -129,7 +130,15 @@ public class CommonTermsQuery extends Query { collectTermStates(reader, leaves, contextArray, queryTerms); return buildQuery(maxDoc, contextArray, queryTerms); } - + + @Override + public void visit(QueryVisitor visitor) { + QueryVisitor v = visitor.getShouldMatchVisitor(this); + for (Term term : terms) { + v.visitLeaf(this, term); + } + } + protected int calcLowFreqMinimumNumberShouldMatch(int numOptional) { return minNrShouldMatch(lowFreqMinNrShouldMatch, numOptional); } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionQuery.java index f996306a72..0a1906f1ac 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionQuery.java @@ -150,7 +150,6 @@ public class FunctionQuery extends Query { return new FunctionQuery.FunctionWeight(searcher, boost); } - /** Prints a user-readable version of this query. */ @Override public String toString(String field) diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java index 5268a430a3..1e4fe1e0c6 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/FunctionScoreQuery.java @@ -31,6 +31,7 @@ import org.apache.lucene.search.FilterScorer; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Matches; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Weight; @@ -119,6 +120,11 @@ public final class FunctionScoreQuery extends Query { return new FunctionScoreQuery(rewritten, source); } + @Override + public void visit(QueryVisitor visitor) { + in.visit(visitor.getMatchingVisitor(this)); + } + @Override public String toString(String field) { return "FunctionScoreQuery(" + in.toString(field) + ", scored by " + source.toString() + ")"; diff --git a/lucene/queries/src/java/org/apache/lucene/queries/mlt/MoreLikeThisQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/mlt/MoreLikeThisQuery.java index 9f3310c7a2..1dc6a68a4c 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/mlt/MoreLikeThisQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/mlt/MoreLikeThisQuery.java @@ -16,17 +16,18 @@ */ package org.apache.lucene.queries.mlt; +import java.io.IOException; +import java.io.StringReader; +import java.util.Arrays; +import java.util.Objects; +import java.util.Set; + import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Query; - -import java.io.IOException; -import java.io.StringReader; -import java.util.Arrays; -import java.util.Set; -import java.util.Objects; +import org.apache.lucene.search.QueryVisitor; /** * A simple wrapper for MoreLikeThis for use in scenarios where a Query object is required eg @@ -77,6 +78,14 @@ public class MoreLikeThisQuery extends Query { return newBq.build(); } + @Override + public void visit(QueryVisitor visitor) { + QueryVisitor v = visitor.getShouldMatchVisitor(this); + for (String field : moreLikeFields) { + v.visitLeaf(this, field, () -> t -> true); + } + } + /* (non-Javadoc) * @see org.apache.lucene.search.Query#toString(java.lang.String) */ diff --git a/lucene/queries/src/java/org/apache/lucene/queries/payloads/PayloadScoreQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/payloads/PayloadScoreQuery.java index bd5d927c62..2342b442ef 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/payloads/PayloadScoreQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/payloads/PayloadScoreQuery.java @@ -30,6 +30,7 @@ import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.LeafSimScorer; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.spans.FilterSpans; import org.apache.lucene.search.spans.SpanCollector; @@ -86,6 +87,11 @@ public class PayloadScoreQuery extends SpanQuery { return super.rewrite(reader); } + @Override + public void visit(QueryVisitor visitor) { + wrappedQuery.visit(visitor.getMatchingVisitor(this)); + } + @Override public String toString(String field) { diff --git a/lucene/queries/src/java/org/apache/lucene/queries/payloads/SpanPayloadCheckQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/payloads/SpanPayloadCheckQuery.java index a9d3bfb2da..cafc0aeaa9 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/payloads/SpanPayloadCheckQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/payloads/SpanPayloadCheckQuery.java @@ -30,6 +30,7 @@ import org.apache.lucene.index.Terms; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.LeafSimScorer; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.spans.FilterSpans; import org.apache.lucene.search.spans.FilterSpans.AcceptStatus; @@ -77,6 +78,11 @@ public class SpanPayloadCheckQuery extends SpanQuery { return super.rewrite(reader); } + @Override + public void visit(QueryVisitor visitor) { + match.visit(visitor.getMatchingVisitor(this)); + } + /** * Weight that pulls its Spans using a PayloadSpanCollector */ diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/queries/FuzzyLikeThisQuery.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/queries/FuzzyLikeThisQuery.java index a7898f7ae0..6a86520828 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/queries/FuzzyLikeThisQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/queries/FuzzyLikeThisQuery.java @@ -41,6 +41,7 @@ import org.apache.lucene.search.ConstantScoreQuery; import org.apache.lucene.search.FuzzyTermsEnum; import org.apache.lucene.search.MaxNonCompetitiveBoostAttribute; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.similarities.ClassicSimilarity; import org.apache.lucene.search.similarities.TFIDFSimilarity; @@ -330,7 +331,15 @@ public class FuzzyLikeThisQuery extends Query // booleans with a minimum-should-match of NumFields-1? return bq.build(); } - + + @Override + public void visit(QueryVisitor visitor) { + QueryVisitor v = visitor.getShouldMatchVisitor(this); + for (FieldVals vals : fieldVals) { + v.visitLeaf(this, vals.fieldName, () -> t -> true); // TODO build automaton from FuzzyTermsEnum? + } + } + //Holds info for a fuzzy term variant - initially score is set to edit distance (for ranking best // term variants) then is reset with IDF for use in ranking against all other // terms/fields diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/BM25FQuery.java b/lucene/sandbox/src/java/org/apache/lucene/search/BM25FQuery.java index b02989df46..1c2f4e3f0a 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/BM25FQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/BM25FQuery.java @@ -219,6 +219,14 @@ public final class BM25FQuery extends Query { return this; } + @Override + public void visit(QueryVisitor visitor) { + QueryVisitor v = visitor.getShouldMatchVisitor(this); + for (Term term : fieldTerms) { + v.visitLeaf(this, term); + } + } + private BooleanQuery rewriteToBoolean() { // rewrite to a simple disjunction if the score is not needed. BooleanQuery.Builder bq = new BooleanQuery.Builder(); diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/CoveringQuery.java b/lucene/sandbox/src/java/org/apache/lucene/search/CoveringQuery.java index fd89888bb8..a4f250d0d8 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/CoveringQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/CoveringQuery.java @@ -109,6 +109,14 @@ public final class CoveringQuery extends Query { return super.rewrite(reader); } + @Override + public void visit(QueryVisitor visitor) { + QueryVisitor v = visitor.getShouldMatchVisitor(this); + for (Query query : queries) { + query.visit(v); + } + } + @Override public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { final List weights = new ArrayList<>(queries.size()); diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/DocValuesNumbersQuery.java b/lucene/sandbox/src/java/org/apache/lucene/search/DocValuesNumbersQuery.java index f018df4ca1..3ec0f5da28 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/DocValuesNumbersQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/DocValuesNumbersQuery.java @@ -129,4 +129,5 @@ public class DocValuesNumbersQuery extends Query { }; } + } diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/TermAutomatonQuery.java b/lucene/sandbox/src/java/org/apache/lucene/search/TermAutomatonQuery.java index 81426a7ffe..1081faf585 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/TermAutomatonQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/TermAutomatonQuery.java @@ -492,4 +492,12 @@ public class TermAutomatonQuery extends Query { // TODO: we could maybe also rewrite to union of PhraseQuery (pull all finite strings) if it's "worth it"? return this; } + + @Override + public void visit(QueryVisitor visitor) { + QueryVisitor v = visitor.getShouldMatchVisitor(this); + for (BytesRef term : termToID.keySet()) { + v.visitLeaf(this, new Term(field, term)); + } + } } diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/ConjunctionIntervalsSource.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/ConjunctionIntervalsSource.java index ec4341de70..5e42ca807b 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/ConjunctionIntervalsSource.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/ConjunctionIntervalsSource.java @@ -21,15 +21,14 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Objects; -import java.util.Set; import java.util.stream.Collectors; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.Term; import org.apache.lucene.search.FilterMatchesIterator; import org.apache.lucene.search.MatchesIterator; import org.apache.lucene.search.MatchesUtils; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; class ConjunctionIntervalsSource extends IntervalsSource { @@ -56,9 +55,10 @@ class ConjunctionIntervalsSource extends IntervalsSource { } @Override - public void extractTerms(String field, Set terms) { + public void visit(String field, QueryVisitor visitor) { + QueryVisitor v = visitor.getMatchingVisitor(new IntervalQuery(field, this)); for (IntervalsSource source : subSources) { - source.extractTerms(field, terms); + source.visit(field, v); } } diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/DifferenceIntervalsSource.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/DifferenceIntervalsSource.java index 7289d04ba2..823151318c 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/DifferenceIntervalsSource.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/DifferenceIntervalsSource.java @@ -19,11 +19,10 @@ package org.apache.lucene.search.intervals; import java.io.IOException; import java.util.Objects; -import java.util.Set; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.Term; import org.apache.lucene.search.MatchesIterator; +import org.apache.lucene.search.QueryVisitor; class DifferenceIntervalsSource extends IntervalsSource { @@ -83,8 +82,10 @@ class DifferenceIntervalsSource extends IntervalsSource { } @Override - public void extractTerms(String field, Set terms) { - minuend.extractTerms(field, terms); + public void visit(String field, QueryVisitor visitor) { + IntervalQuery q = new IntervalQuery(field, this); + minuend.visit(field, visitor.getMatchingVisitor(q)); + subtrahend.visit(field, visitor.getNonMatchingVisitor(q)); } @Override diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/DisjunctionIntervalsSource.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/DisjunctionIntervalsSource.java index b28088513a..033e5b91dc 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/DisjunctionIntervalsSource.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/DisjunctionIntervalsSource.java @@ -21,14 +21,13 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Objects; -import java.util.Set; import java.util.stream.Collectors; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.Term; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.MatchesIterator; import org.apache.lucene.search.MatchesUtils; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.util.PriorityQueue; class DisjunctionIntervalsSource extends IntervalsSource { @@ -84,9 +83,10 @@ class DisjunctionIntervalsSource extends IntervalsSource { } @Override - public void extractTerms(String field, Set terms) { + public void visit(String field, QueryVisitor visitor) { + QueryVisitor v = visitor.getShouldMatchVisitor(new IntervalQuery(field, this)); for (IntervalsSource source : subSources) { - source.extractTerms(field, terms); + source.visit(field, v); } } diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/ExtendedIntervalsSource.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/ExtendedIntervalsSource.java index 864a4b573c..41a8829b94 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/ExtendedIntervalsSource.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/ExtendedIntervalsSource.java @@ -19,11 +19,10 @@ package org.apache.lucene.search.intervals; import java.io.IOException; import java.util.Objects; -import java.util.Set; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.Term; import org.apache.lucene.search.MatchesIterator; +import org.apache.lucene.search.QueryVisitor; class ExtendedIntervalsSource extends IntervalsSource { @@ -57,8 +56,8 @@ class ExtendedIntervalsSource extends IntervalsSource { } @Override - public void extractTerms(String field, Set terms) { - source.extractTerms(field, terms); + public void visit(String field, QueryVisitor visitor) { + source.visit(field, visitor); } @Override diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/FilteredIntervalsSource.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/FilteredIntervalsSource.java index c2b4d6012c..b3341297c9 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/FilteredIntervalsSource.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/FilteredIntervalsSource.java @@ -19,11 +19,10 @@ package org.apache.lucene.search.intervals; import java.io.IOException; import java.util.Objects; -import java.util.Set; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.Term; import org.apache.lucene.search.MatchesIterator; +import org.apache.lucene.search.QueryVisitor; /** * An IntervalsSource that filters the intervals from another IntervalsSource @@ -83,8 +82,8 @@ public abstract class FilteredIntervalsSource extends IntervalsSource { } @Override - public void extractTerms(String field, Set terms) { - in.extractTerms(field, terms); + public void visit(String field, QueryVisitor visitor) { + in.visit(field, visitor); } @Override diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/FixedFieldIntervalsSource.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/FixedFieldIntervalsSource.java index 7776a2b543..ab24ee359e 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/FixedFieldIntervalsSource.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/FixedFieldIntervalsSource.java @@ -19,11 +19,10 @@ package org.apache.lucene.search.intervals; import java.io.IOException; import java.util.Objects; -import java.util.Set; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.Term; import org.apache.lucene.search.MatchesIterator; +import org.apache.lucene.search.QueryVisitor; class FixedFieldIntervalsSource extends IntervalsSource { @@ -46,8 +45,8 @@ class FixedFieldIntervalsSource extends IntervalsSource { } @Override - public void extractTerms(String field, Set terms) { - source.extractTerms(this.field, terms); + public void visit(String field, QueryVisitor visitor) { + source.visit(this.field, visitor); } @Override diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalQuery.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalQuery.java index 62fe0679bf..d25eaa363c 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalQuery.java @@ -30,6 +30,7 @@ import org.apache.lucene.search.Matches; import org.apache.lucene.search.MatchesIterator; import org.apache.lucene.search.MatchesUtils; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Weight; @@ -120,6 +121,11 @@ public final class IntervalQuery extends Query { return new IntervalWeight(this, boost, scoreMode); } + @Override + public void visit(QueryVisitor visitor) { + intervalsSource.visit(field, visitor); + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -147,7 +153,7 @@ public final class IntervalQuery extends Query { @Override public void extractTerms(Set terms) { - intervalsSource.extractTerms(field, terms); + intervalsSource.visit(field, (q, t) -> terms.add(t)); } @Override diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalsSource.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalsSource.java index dc4161fa05..1c49d248fa 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalsSource.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/IntervalsSource.java @@ -18,11 +18,10 @@ package org.apache.lucene.search.intervals; import java.io.IOException; -import java.util.Set; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.Term; import org.apache.lucene.search.MatchesIterator; +import org.apache.lucene.search.QueryVisitor; /** * A helper class for {@link IntervalQuery} that provides an {@link IntervalIterator} @@ -56,11 +55,9 @@ public abstract class IntervalsSource { public abstract MatchesIterator matches(String field, LeafReaderContext ctx, int doc) throws IOException; /** - * Expert: collect {@link Term} objects from this source - * @param field the field to be scored - * @param terms a {@link Set} which terms should be added to + * Expert: visit the tree of sources */ - public abstract void extractTerms(String field, Set terms); + public abstract void visit(String field, QueryVisitor visitor); /** * Return the minimum possible width of an interval returned by this source diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/MinimumShouldMatchIntervalsSource.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/MinimumShouldMatchIntervalsSource.java index 1935c628ee..96f3c571ea 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/MinimumShouldMatchIntervalsSource.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/MinimumShouldMatchIntervalsSource.java @@ -25,15 +25,14 @@ import java.util.IdentityHashMap; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Set; import java.util.stream.Collectors; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.Term; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.MatchesIterator; import org.apache.lucene.search.MatchesUtils; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.util.PriorityQueue; class MinimumShouldMatchIntervalsSource extends IntervalsSource { @@ -85,9 +84,10 @@ class MinimumShouldMatchIntervalsSource extends IntervalsSource { } @Override - public void extractTerms(String field, Set terms) { + public void visit(String field, QueryVisitor visitor) { + QueryVisitor v = visitor.getShouldMatchVisitor(new IntervalQuery(field, this)); for (IntervalsSource source : sources) { - source.extractTerms(field, terms); + source.visit(field, v); } } diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/MultiTermIntervalsSource.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/MultiTermIntervalsSource.java index 7689d1dc48..a8a7b810e1 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/MultiTermIntervalsSource.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/MultiTermIntervalsSource.java @@ -21,14 +21,13 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Objects; -import java.util.Set; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.Term; import org.apache.lucene.index.Terms; import org.apache.lucene.index.TermsEnum; import org.apache.lucene.search.MatchesIterator; import org.apache.lucene.search.MatchesUtils; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.automaton.CompiledAutomaton; @@ -89,8 +88,8 @@ class MultiTermIntervalsSource extends IntervalsSource { } @Override - public void extractTerms(String field, Set terms) { - + public void visit(String field, QueryVisitor visitor) { + visitor.visitLeaf(new IntervalQuery(field, this), field, () -> QueryVisitor.matchesAutomaton(automaton)); } @Override diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/OffsetIntervalsSource.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/OffsetIntervalsSource.java index b2ca30224e..078c64b027 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/OffsetIntervalsSource.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/OffsetIntervalsSource.java @@ -19,11 +19,10 @@ package org.apache.lucene.search.intervals; import java.io.IOException; import java.util.Objects; -import java.util.Set; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.Term; import org.apache.lucene.search.MatchesIterator; +import org.apache.lucene.search.QueryVisitor; /** * Tracks a reference intervals source, and produces a pseudo-interval that appears @@ -144,8 +143,8 @@ class OffsetIntervalsSource extends IntervalsSource { } @Override - public void extractTerms(String field, Set terms) { - in.extractTerms(field, terms); + public void visit(String field, QueryVisitor visitor) { + in.visit(field, visitor); } @Override diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/TermIntervalsSource.java b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/TermIntervalsSource.java index 4539d2f2cb..6dddde056d 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/search/intervals/TermIntervalsSource.java +++ b/lucene/sandbox/src/java/org/apache/lucene/search/intervals/TermIntervalsSource.java @@ -19,7 +19,6 @@ package org.apache.lucene.search.intervals; import java.io.IOException; import java.util.Objects; -import java.util.Set; import org.apache.lucene.codecs.lucene50.Lucene50PostingsFormat; import org.apache.lucene.codecs.lucene50.Lucene50PostingsReader; @@ -32,6 +31,7 @@ import org.apache.lucene.index.TermsEnum; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.MatchesIterator; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.TwoPhaseIterator; import org.apache.lucene.util.BytesRef; @@ -227,8 +227,8 @@ class TermIntervalsSource extends IntervalsSource { } @Override - public void extractTerms(String field, Set terms) { - terms.add(new Term(field, term)); + public void visit(String field, QueryVisitor visitor) { + visitor.visitLeaf(new IntervalQuery(field, this), new Term(field, term)); } /** A guess of diff --git a/lucene/test-framework/src/java/org/apache/lucene/search/AssertingQuery.java b/lucene/test-framework/src/java/org/apache/lucene/search/AssertingQuery.java index b3d2f8116c..8989106b95 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/search/AssertingQuery.java +++ b/lucene/test-framework/src/java/org/apache/lucene/search/AssertingQuery.java @@ -78,4 +78,9 @@ public final class AssertingQuery extends Query { } } + @Override + public void visit(QueryVisitor visitor) { + in.visit(visitor); + } + } diff --git a/lucene/test-framework/src/java/org/apache/lucene/search/spans/AssertingSpanQuery.java b/lucene/test-framework/src/java/org/apache/lucene/search/spans/AssertingSpanQuery.java index f24a4ff8fe..1040a3a767 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/search/spans/AssertingSpanQuery.java +++ b/lucene/test-framework/src/java/org/apache/lucene/search/spans/AssertingSpanQuery.java @@ -16,14 +16,15 @@ */ package org.apache.lucene.search.spans; +import java.io.IOException; +import java.util.Objects; + import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; -import java.io.IOException; -import java.util.Objects; - /** Wraps a span query with asserts */ public class AssertingSpanQuery extends SpanQuery { private final SpanQuery in; @@ -60,6 +61,11 @@ public class AssertingSpanQuery extends SpanQuery { } } + @Override + public void visit(QueryVisitor visitor) { + in.visit(visitor); + } + @Override public Query clone() { return new AssertingSpanQuery(in);