diff --git lucene/suggest/src/java/org/apache/lucene/search/spell/HighFrequencyDictionary.java lucene/suggest/src/java/org/apache/lucene/search/spell/HighFrequencyDictionary.java index 826ba28..2245d40 100644 --- lucene/suggest/src/java/org/apache/lucene/search/spell/HighFrequencyDictionary.java +++ lucene/suggest/src/java/org/apache/lucene/search/spell/HighFrequencyDictionary.java @@ -18,6 +18,7 @@ package org.apache.lucene.search.spell; import java.io.IOException; +import java.util.Set; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.TermsEnum; @@ -109,5 +110,15 @@ public class HighFrequencyDictionary implements Dictionary { public boolean hasPayloads() { return false; } + + @Override + public Set contexts() { + return null; + } + + @Override + public boolean hasContexts() { + return false; + } } } diff --git lucene/suggest/src/java/org/apache/lucene/search/suggest/BufferedInputIterator.java lucene/suggest/src/java/org/apache/lucene/search/suggest/BufferedInputIterator.java index b9772fa..b3c5f0b 100644 --- lucene/suggest/src/java/org/apache/lucene/search/suggest/BufferedInputIterator.java +++ lucene/suggest/src/java/org/apache/lucene/search/suggest/BufferedInputIterator.java @@ -18,6 +18,7 @@ package org.apache.lucene.search.suggest; */ import java.io.IOException; +import java.util.Set; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BytesRef; @@ -85,4 +86,14 @@ public class BufferedInputIterator implements InputIterator { public boolean hasPayloads() { return hasPayloads; } + + @Override + public Set contexts() { + return null; + } + + @Override + public boolean hasContexts() { + return false; + } } diff --git lucene/suggest/src/java/org/apache/lucene/search/suggest/DocumentDictionary.java lucene/suggest/src/java/org/apache/lucene/search/suggest/DocumentDictionary.java index d948e20..03f7ec0 100644 --- lucene/suggest/src/java/org/apache/lucene/search/suggest/DocumentDictionary.java +++ lucene/suggest/src/java/org/apache/lucene/search/suggest/DocumentDictionary.java @@ -192,5 +192,15 @@ public class DocumentDictionary implements Dictionary { } return relevantFields; } + + @Override + public Set contexts() { + return null; + } + + @Override + public boolean hasContexts() { + return false; + } } } diff --git lucene/suggest/src/java/org/apache/lucene/search/suggest/FileDictionary.java lucene/suggest/src/java/org/apache/lucene/search/suggest/FileDictionary.java index 5e59685..74683d9 100644 --- lucene/suggest/src/java/org/apache/lucene/search/suggest/FileDictionary.java +++ lucene/suggest/src/java/org/apache/lucene/search/suggest/FileDictionary.java @@ -19,6 +19,7 @@ package org.apache.lucene.search.suggest; import java.io.*; +import java.util.Set; import org.apache.lucene.search.spell.Dictionary; import org.apache.lucene.util.BytesRef; @@ -209,5 +210,15 @@ public class FileDictionary implements Dictionary { curWeight = (long)Double.parseDouble(weight); } } + + @Override + public Set contexts() { + return null; + } + + @Override + public boolean hasContexts() { + return false; + } } } diff --git lucene/suggest/src/java/org/apache/lucene/search/suggest/InputIterator.java lucene/suggest/src/java/org/apache/lucene/search/suggest/InputIterator.java index bda1332..fc35cdc 100644 --- lucene/suggest/src/java/org/apache/lucene/search/suggest/InputIterator.java +++ lucene/suggest/src/java/org/apache/lucene/search/suggest/InputIterator.java @@ -18,6 +18,7 @@ package org.apache.lucene.search.suggest; */ import java.io.IOException; +import java.util.Set; import org.apache.lucene.search.suggest.Lookup.LookupResult; // javadocs import org.apache.lucene.search.suggest.analyzing.AnalyzingInfixSuggester; // javadocs @@ -44,6 +45,12 @@ public interface InputIterator extends BytesRefIterator { /** Returns true if the iterator has payloads */ public boolean hasPayloads(); + /** A term's contexts context can be used to filter suggestions*/ + public Set contexts(); + + /** Returns true if the iterator has contexts */ + public boolean hasContexts(); + /** * Wraps a BytesRefIterator as a suggester InputIterator, with all weights * set to 1 and carries no payload @@ -79,5 +86,15 @@ public interface InputIterator extends BytesRefIterator { public boolean hasPayloads() { return false; } + + @Override + public Set contexts() { + return null; + } + + @Override + public boolean hasContexts() { + return false; + } } } diff --git lucene/suggest/src/java/org/apache/lucene/search/suggest/Lookup.java lucene/suggest/src/java/org/apache/lucene/search/suggest/Lookup.java index 3b4e09c..c117476 100644 --- lucene/suggest/src/java/org/apache/lucene/search/suggest/Lookup.java +++ lucene/suggest/src/java/org/apache/lucene/search/suggest/Lookup.java @@ -22,6 +22,7 @@ import java.io.InputStream; import java.io.OutputStream; import java.util.Comparator; import java.util.List; +import java.util.Set; import org.apache.lucene.search.spell.Dictionary; import org.apache.lucene.util.BytesRef; @@ -50,31 +51,53 @@ public abstract class Lookup { /** the key's payload (null if not present) */ public final BytesRef payload; + /** the key's contexts (null if not present) */ + public final Set contexts; + /** * Create a new result from a key+weight pair. */ public LookupResult(CharSequence key, long value) { - this(key, value, null); + this(key, null, value, null, null); } /** * Create a new result from a key+weight+payload triple. */ public LookupResult(CharSequence key, long value, BytesRef payload) { - this.key = key; - this.highlightKey = null; - this.value = value; - this.payload = payload; + this(key, null, value, payload, null); } - + /** * Create a new result from a key+highlightKey+weight+payload triple. */ public LookupResult(CharSequence key, Object highlightKey, long value, BytesRef payload) { + this(key, highlightKey, value, payload, null); + } + + /** + * Create a new result from a key+weight+payload+contexts triple. + */ + public LookupResult(CharSequence key, long value, BytesRef payload, Set contexts) { + this(key, null, value, payload, contexts); + } + + /** + * Create a new result from a key+weight+contexts triple. + */ + public LookupResult(CharSequence key, long value, Set contexts) { + this(key, null, value, null, contexts); + } + + /** + * Create a new result from a key+highlightKey+weight+payload+contexts triple. + */ + public LookupResult(CharSequence key, Object highlightKey, long value, BytesRef payload, Set contexts) { this.key = key; this.highlightKey = highlightKey; this.value = value; this.payload = payload; + this.contexts = contexts; } @Override @@ -177,11 +200,25 @@ public abstract class Lookup { * Look up a key and return possible completion for this key. * @param key lookup key. Depending on the implementation this may be * a prefix, misspelling, or even infix. + * @param contexts contexts to filter the lookup by + * @param num maximum number of results to return + * @return a list of possible completions, with their relative weight (e.g. popularity) + */ + public List lookup(CharSequence key, Set contexts, int num) { + throw new UnsupportedOperationException("contexts is not supported"); + } + + /** + * Look up a key and return possible completion for this key. + * @param key lookup key. Depending on the implementation this may be + * a prefix, misspelling, or even infix. * @param onlyMorePopular return only more popular results * @param num maximum number of results to return * @return a list of possible completions, with their relative weight (e.g. popularity) */ - public abstract List lookup(CharSequence key, boolean onlyMorePopular, int num); + public List lookup(CharSequence key, boolean onlyMorePopular, int num) { + throw new UnsupportedOperationException("contexts is required"); + } /** @@ -206,4 +243,5 @@ public abstract class Lookup { * @return ram size of the lookup implementation in bytes */ public abstract long sizeInBytes(); + } diff --git lucene/suggest/src/java/org/apache/lucene/search/suggest/SortedInputIterator.java lucene/suggest/src/java/org/apache/lucene/search/suggest/SortedInputIterator.java index d804f38..c022c11 100644 --- lucene/suggest/src/java/org/apache/lucene/search/suggest/SortedInputIterator.java +++ lucene/suggest/src/java/org/apache/lucene/search/suggest/SortedInputIterator.java @@ -20,6 +20,7 @@ package org.apache.lucene.search.suggest; import java.io.File; import java.io.IOException; import java.util.Comparator; +import java.util.Set; import org.apache.lucene.search.suggest.Sort.ByteSequencesReader; import org.apache.lucene.search.suggest.Sort.ByteSequencesWriter; @@ -223,4 +224,14 @@ public class SortedInputIterator implements InputIterator { scratch.length -= payloadLength; // payload return payloadScratch; } + + @Override + public Set contexts() { + return source.contexts(); + } + + @Override + public boolean hasContexts() { + return source.hasContexts(); + } } diff --git lucene/suggest/src/java/org/apache/lucene/search/suggest/analyzing/AnalyzingSuggester.java lucene/suggest/src/java/org/apache/lucene/search/suggest/analyzing/AnalyzingSuggester.java index 4278440..5b56ced 100644 --- lucene/suggest/src/java/org/apache/lucene/search/suggest/analyzing/AnalyzingSuggester.java +++ lucene/suggest/src/java/org/apache/lucene/search/suggest/analyzing/AnalyzingSuggester.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.HashSet; @@ -182,7 +183,10 @@ public class AnalyzingSuggester extends Lookup { private boolean hasPayloads; + private boolean hasContexts; + private static final int PAYLOAD_SEP = '\u001f'; + /** Whether position holes should appear in the automaton. */ private boolean preservePositionIncrements; @@ -319,9 +323,11 @@ public class AnalyzingSuggester extends Lookup { private static class AnalyzingComparator implements Comparator { private final boolean hasPayloads; + private final boolean hasContexts; - public AnalyzingComparator(boolean hasPayloads) { + public AnalyzingComparator(boolean hasPayloads, boolean hasContexts) { this.hasPayloads = hasPayloads; + this.hasContexts = hasContexts; } private final ByteArrayDataInput readerA = new ByteArrayDataInput(); @@ -367,7 +373,18 @@ public class AnalyzingSuggester extends Lookup { scratchB.length = readerB.readShort(); scratchA.offset = readerA.getPosition(); scratchB.offset = readerB.getPosition(); - } else { + } + if (hasContexts) { + if (hasPayloads) { + readerA.skipBytes(scratchA.length); + readerB.skipBytes(scratchB.length); + } + scratchA.length = readerA.readShort(); + scratchB.length = readerB.readShort(); + scratchA.offset = readerA.getPosition(); + scratchB.offset = readerB.getPosition(); + } + if (!hasPayloads && !hasContexts) { scratchA.offset = readerA.getPosition(); scratchB.offset = readerB.getPosition(); scratchA.length = a.length - scratchA.offset; @@ -377,15 +394,52 @@ public class AnalyzingSuggester extends Lookup { return scratchA.compareTo(scratchB); } } + + /** Returns a set of context+term */ + protected Set getTermsWithContext(Set contexts, BytesRef scratch) { + return new HashSet<>(Arrays.asList(scratch)); + } + /** contexts -> BytesRef representation of contexts seperated by {{@link #PAYLOAD_SEP}*/ + private BytesRef getContextsPayload(Set contexts) { + if (contexts != null) { + int requiredLen = contexts.size() - 1; // for the seperators + Object[] contextArray = contexts.toArray(); + for (int i = 0; i < contextArray.length; i++) { + requiredLen += ((BytesRef)contextArray[i]).length; + } + BytesRef contextsPayload = new BytesRef(requiredLen); + int currentOffset = 0; + BytesRef context; + for (int i = 0; i < contextArray.length; i++) { + context = ((BytesRef)contextArray[i]); + System.arraycopy(context.bytes, context.offset, contextsPayload.bytes, currentOffset, context.length); + + if (i != contextArray.length - 1) { + currentOffset += context.length; + contextsPayload.bytes[currentOffset] = PAYLOAD_SEP; + currentOffset++; + } + + } + contextsPayload.length = contextsPayload.bytes.length; + return contextsPayload; + } + return null; + } + @Override public void build(InputIterator iterator) throws IOException { String prefix = getClass().getSimpleName(); + if (prefix.length() == 0) { + prefix = getClass().getEnclosingClass().getSimpleName(); + } File directory = Sort.defaultTempDir(); File tempInput = File.createTempFile(prefix, ".input", directory); File tempSorted = File.createTempFile(prefix, ".sorted", directory); hasPayloads = iterator.hasPayloads(); + hasContexts = iterator.hasContexts(); Sort.ByteSequencesWriter writer = new Sort.ByteSequencesWriter(tempInput); Sort.ByteSequencesReader reader = null; @@ -403,66 +457,97 @@ public class AnalyzingSuggester extends Lookup { Set paths = toFiniteStrings(surfaceForm, ts2a); maxAnalyzedPathsForOneInput = Math.max(maxAnalyzedPathsForOneInput, paths.size()); - + BytesRef contexts = (hasContexts) ? getContextsPayload(iterator.contexts()) : null; + for (IntsRef path : paths) { - Util.toBytesRef(path, scratch); - - // length of the analyzed text (FST input) - if (scratch.length > Short.MAX_VALUE-2) { - throw new IllegalArgumentException("cannot handle analyzed forms > " + (Short.MAX_VALUE-2) + " in length (got " + scratch.length + ")"); - } - short analyzedLength = (short) scratch.length; - - // compute the required length: - // analyzed sequence + weight (4) + surface + analyzedLength (short) - int requiredLength = analyzedLength + 4 + surfaceForm.length + 2; - - BytesRef payload; - - if (hasPayloads) { - if (surfaceForm.length > (Short.MAX_VALUE-2)) { - throw new IllegalArgumentException("cannot handle surface form > " + (Short.MAX_VALUE-2) + " in length (got " + surfaceForm.length + ")"); + for (BytesRef termWithContext: getTermsWithContext(iterator.contexts(), scratch)) { + //BytesRef termWithContext = termEntry.getValue(); + // length of the analyzed text (FST input) + if (termWithContext.length > Short.MAX_VALUE-2) { + throw new IllegalArgumentException("cannot handle analyzed forms > " + (Short.MAX_VALUE-2) + " in length (got " + termWithContext.length + ")"); } - payload = iterator.payload(); - // payload + surfaceLength (short) - requiredLength += payload.length + 2; - } else { - payload = null; - } - - buffer = ArrayUtil.grow(buffer, requiredLength); - - output.reset(buffer); - - output.writeShort(analyzedLength); - - output.writeBytes(scratch.bytes, scratch.offset, scratch.length); - - output.writeInt(encodeWeight(iterator.weight())); - - if (hasPayloads) { - for(int i=0;i (Short.MAX_VALUE-2)) { + throw new IllegalArgumentException("cannot handle surface form > " + (Short.MAX_VALUE-2) + " in length (got " + surfaceForm.length + ")"); } + payload = iterator.payload(); + // payload + surfaceLength (short) + requiredLength += payload.length + 2; + } else { + payload = null; } - output.writeShort((short) surfaceForm.length); - output.writeBytes(surfaceForm.bytes, surfaceForm.offset, surfaceForm.length); - output.writeBytes(payload.bytes, payload.offset, payload.length); - } else { - output.writeBytes(surfaceForm.bytes, surfaceForm.offset, surfaceForm.length); + + if (hasContexts) { + if (payload != null && payload.length > (Short.MAX_VALUE-2)) { + throw new IllegalArgumentException("cannot handle payload form > " + (Short.MAX_VALUE-2) + " in length (got " + payload.length + ")"); + } + requiredLength += contexts.length + 2; + } + + buffer = ArrayUtil.grow(buffer, requiredLength); + + output.reset(buffer); + + output.writeShort(analyzedLength); + + output.writeBytes(termWithContext.bytes, termWithContext.offset, termWithContext.length); + + output.writeInt(encodeWeight(iterator.weight())); + + if (hasPayloads || hasContexts) { + for(int i=0;i " + cost + ": " + surface.utf8ToString()); - if (!hasPayloads) { + if (!hasPayloads && !hasContexts) { builder.add(scratchInts, outputs.newPair(cost, BytesRef.deepCopyOf(surface))); } else { - int payloadOffset = input.getPosition() + surface.length; - int payloadLength = scratch.length - payloadOffset; - BytesRef br = new BytesRef(surface.length + 1 + payloadLength); - System.arraycopy(surface.bytes, surface.offset, br.bytes, 0, surface.length); - br.bytes[surface.length] = PAYLOAD_SEP; - System.arraycopy(scratch.bytes, payloadOffset, br.bytes, surface.length+1, payloadLength); - br.length = br.bytes.length; - builder.add(scratchInts, outputs.newPair(cost, br)); + if ((hasPayloads && !hasContexts) || (!hasPayloads && hasContexts)) { // if only one of payload or context is present + int payloadOffset = input.getPosition() + surface.length; + int payloadLength = scratch.length - payloadOffset; + BytesRef br = new BytesRef(surface.length + 1 + payloadLength); + System.arraycopy(surface.bytes, surface.offset, br.bytes, 0, surface.length); + br.bytes[surface.length] = PAYLOAD_SEP; + System.arraycopy(scratch.bytes, payloadOffset, br.bytes, surface.length+1, payloadLength); + br.length = br.bytes.length; + + builder.add(scratchInts, outputs.newPair(cost, br)); + } else { + input.skipBytes(surface.length); + short payloadLength = input.readShort(); + int payloadOffset = input.getPosition(); + int categoryOffset = input.getPosition() + payloadLength; + int categoryLength = scratch.length - categoryOffset; + BytesRef br = new BytesRef(surface.length + 1 + payloadLength + 1 + categoryLength); + System.arraycopy(surface.bytes, surface.offset, br.bytes, 0, surface.length); + br.bytes[surface.length] = PAYLOAD_SEP; + System.arraycopy(scratch.bytes, payloadOffset, br.bytes, surface.length+1, payloadLength); + br.bytes[surface.length + 1 + payloadLength] = PAYLOAD_SEP; + System.arraycopy(scratch.bytes, categoryOffset, br.bytes, surface.length+1+payloadLength+1, categoryLength); + br.length = br.bytes.length; + builder.add(scratchInts, outputs.newPair(cost, br)); + } } } fst = builder.finish(); @@ -580,7 +682,15 @@ public class AnalyzingSuggester extends Lookup { fst.save(dataOut); dataOut.writeVInt(maxAnalyzedPathsForOneInput); - dataOut.writeByte((byte) (hasPayloads ? 1 : 0)); + int metaData = 0; + if (hasContexts && hasPayloads) { + metaData = 2; + } else if (hasPayloads) { + metaData = 1; + } else if (hasContexts) { + metaData = 3; + } + dataOut.writeByte((byte) metaData); } finally { IOUtils.close(output); } @@ -593,7 +703,21 @@ public class AnalyzingSuggester extends Lookup { try { this.fst = new FST>(dataIn, new PairOutputs(PositiveIntOutputs.getSingleton(), ByteSequenceOutputs.getSingleton())); maxAnalyzedPathsForOneInput = dataIn.readVInt(); - hasPayloads = dataIn.readByte() == 1; + int metaData = dataIn.readByte(); + if (metaData == 0) { + hasContexts = false; + hasPayloads = false; + } else if (metaData == 1) { + hasContexts = false; + hasPayloads = true; + } else if (metaData == 2) { + hasContexts = true; + hasPayloads = true; + } else if (metaData == 3) { + hasContexts = true; + hasPayloads = false; + } else { // TODO:: throw error? + } } finally { IOUtils.close(input); } @@ -602,7 +726,7 @@ public class AnalyzingSuggester extends Lookup { private LookupResult getLookupResult(Long output1, BytesRef output2, CharsRef spare) { LookupResult result; - if (hasPayloads) { + if (hasPayloads || hasContexts) { int sepIndex = -1; for(int i=0;i contexts = new HashSet<>(); + int lastSep = sepIndex; + for(int i = sepIndex+1; i < output2.length; i++) { + if (output2.bytes[output2.offset+i] == PAYLOAD_SEP) { + int contextLen = i - lastSep - 1; + BytesRef context = new BytesRef(contextLen); + System.arraycopy(output2.bytes, lastSep+1, context.bytes, context.offset, contextLen); + context.length = context.bytes.length; + contexts.add(context); + lastSep = i; + } + } + int contextLen = output2.length - lastSep - 1; + BytesRef context = new BytesRef(contextLen); + System.arraycopy(output2.bytes, lastSep+1, context.bytes, context.offset, contextLen); + context.length = context.bytes.length; + contexts.add(context); + result = new LookupResult(spare.toString(), decodeWeight(output1), contexts); + } else { // has both payload and contexts + Set contexts = new HashSet<>(); + boolean first = true; + BytesRef payload = null; + int lastSep = sepIndex; + for(int i = sepIndex+1; i < output2.length; i++) { + if (output2.bytes[output2.offset+i] == PAYLOAD_SEP) { + if (first) { + int payloadLen = i - lastSep - 1; + payload = new BytesRef(payloadLen); + System.arraycopy(output2.bytes, lastSep+1, payload.bytes, payload.offset, payloadLen); + payload.length = payloadLen; + first = false; + } else { + int contextLen = i - lastSep - 1; + BytesRef context = new BytesRef(contextLen); + System.arraycopy(output2.bytes, lastSep+1, context.bytes, context.offset, contextLen); + context.length = context.bytes.length; + contexts.add(context); + } + lastSep = i; + } + } + int contextLen = output2.length - lastSep - 1; + BytesRef context = new BytesRef(contextLen); + System.arraycopy(output2.bytes, lastSep+1, context.bytes, context.offset, contextLen); + context.length = context.bytes.length; + contexts.add(context); + assert payload != null; + result = new LookupResult(spare.toString(), decodeWeight(output1), payload, contexts); + } } else { spare.grow(output2.length); UnicodeUtil.UTF8toUTF16(output2, spare); diff --git lucene/suggest/src/java/org/apache/lucene/search/suggest/analyzing/ContextAwareAnalyzingSuggester.java lucene/suggest/src/java/org/apache/lucene/search/suggest/analyzing/ContextAwareAnalyzingSuggester.java new file mode 100644 index 0000000..47dd0eb --- /dev/null +++ lucene/suggest/src/java/org/apache/lucene/search/suggest/analyzing/ContextAwareAnalyzingSuggester.java @@ -0,0 +1,161 @@ +package org.apache.lucene.search.suggest.analyzing; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.search.suggest.InputIterator; +import org.apache.lucene.search.suggest.Lookup; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.CharsRef; + +/** + * Suggester build on {@link AnalyzingSuggester}. This suggester is + * aware of contexts and will only suggest terms which were in the + * lookup context. + * @lucene.experimental + */ +public class ContextAwareAnalyzingSuggester extends Lookup { + + private final static int CONTEXT_SEP = '\u001c'; + private final AnalyzingSuggester analyzingSuggester; + /** + * Calls context aware {@link #ContextAwareAnalyzingSuggester(Analyzer,Analyzer,int,int,int,boolean) + * AnalyzingSuggester(analyzer, analyzer, EXACT_FIRST | + * PRESERVE_SEP, 256, -1, true)} + */ + public ContextAwareAnalyzingSuggester(Analyzer analyzer) { + this(analyzer, analyzer, AnalyzingSuggester.EXACT_FIRST | AnalyzingSuggester.PRESERVE_SEP, 256, -1, true); + } + + /** + * Calls context aware {@link #ContextAwareAnalyzingSuggester(Analyzer,Analyzer,int,int,int,boolean) + * AnalyzingSuggester(indexAnalyzer, queryAnalyzer, EXACT_FIRST | + * PRESERVE_SEP, 256, -1, true)} + */ + public ContextAwareAnalyzingSuggester(Analyzer indexAnalyzer, Analyzer queryAnalyzer) { + this(indexAnalyzer, queryAnalyzer, AnalyzingSuggester.EXACT_FIRST | AnalyzingSuggester.PRESERVE_SEP, 256, -1, true); + } + + /** + * Creates a new context aware AnalyzingSuggester suggester + */ + public ContextAwareAnalyzingSuggester(Analyzer indexAnalyzer, Analyzer queryAnalyzer, int options, int maxSurfaceFormsPerAnalyzedForm, int maxGraphExpansions, + boolean preservePositionIncrements) { + this.analyzingSuggester = new AnalyzingSuggester(indexAnalyzer, queryAnalyzer, options, maxSurfaceFormsPerAnalyzedForm, maxGraphExpansions, + preservePositionIncrements) { + @Override + protected Set getTermsWithContext(Set context, BytesRef scratch) { + return ContextAwareAnalyzingSuggester.this.getTermsWithContext(context, scratch); + } + }; + } + + private Set getTermsWithContext(Set contexts, BytesRef scratch) { + Set termEntries = new HashSet<>(); + if (contexts == null || contexts.size() == 0) { + throw new IllegalArgumentException("Missing context"); + } + + for (BytesRef context : contexts) { + for (int i = 0; i < context.length; i++) { + if (context.bytes[i] == CONTEXT_SEP) { + throw new IllegalArgumentException("context cannot contain unit separator character U+001C; this character is reserved"); + } + } + BytesRef termWithContext = new BytesRef(context.length + 1 + scratch.length); + termWithContext.copyBytes(context); + termWithContext.bytes[termWithContext.offset + context.length] = CONTEXT_SEP; + termWithContext.length++; + termWithContext.append(scratch); + termEntries.add(termWithContext); + } + + return termEntries; + } + + private void constructLookupKey(CharSequence contextCharSeq, CharSequence key, BytesRef keyWithContext) { + keyWithContext.copyChars(contextCharSeq); + keyWithContext.bytes[keyWithContext.offset + contextCharSeq.length()] = CONTEXT_SEP; + keyWithContext.length++; + keyWithContext.append(new BytesRef(key)); + } + + @Override + public List lookup(final CharSequence key, Set contexts, int num) { + assert num > 0; + Set keySet = new HashSet<>(); + for (int i = 0; i < key.length(); i++) { + if (key.charAt(i) == 0x1c) { + throw new IllegalArgumentException("lookup key cannot contain unit separator character U+001C; this character is reserved"); + } + } + if (contexts.size() == 1) { + CharSequence contextCharSeq = contexts.iterator().next(); + BytesRef context = new BytesRef(contextCharSeq.length() + 1 + key.length()); + constructLookupKey(contextCharSeq, key, context); + return analyzingSuggester.lookup(new CharsRef(context.utf8ToString()), false, num); + } else { // has multiple contexts + Lookup.LookupPriorityQueue results = new LookupPriorityQueue(num); + for (CharSequence contextCharSeq : contexts) { + BytesRef context = new BytesRef(contextCharSeq.length() + 1 + key.length()); + constructLookupKey(contextCharSeq, key, context); + for (LookupResult r : analyzingSuggester.lookup(new CharsRef(context.utf8ToString()), false, num)) { + if (!keySet.contains(r.key)) { + keySet.add(r.key.toString()); + } else { + continue; + } + results.insertWithOverflow(r); + } + } + + return Arrays.asList(results.getResults()); + } + } + + @Override + public void build(InputIterator tfit) throws IOException { + analyzingSuggester.build(tfit); + } + + @Override + public boolean store(OutputStream output) throws IOException { + return analyzingSuggester.store(output); + } + + @Override + public boolean load(InputStream input) throws IOException { + return analyzingSuggester.load(input); + } + + @Override + public long sizeInBytes() { + return analyzingSuggester.sizeInBytes(); + } + +} diff --git lucene/suggest/src/test/org/apache/lucene/search/suggest/Input.java lucene/suggest/src/test/org/apache/lucene/search/suggest/Input.java index 009f80c..e8834ce 100644 --- lucene/suggest/src/test/org/apache/lucene/search/suggest/Input.java +++ lucene/suggest/src/test/org/apache/lucene/search/suggest/Input.java @@ -17,6 +17,9 @@ package org.apache.lucene.search.suggest; * limitations under the License. */ +import java.util.HashSet; +import java.util.Set; + import org.apache.lucene.util.BytesRef; /** corresponds to {@link InputIterator}'s entries */ @@ -25,28 +28,55 @@ public final class Input { public final long v; public final BytesRef payload; public final boolean hasPayloads; + public final Set contexts; + public final boolean hasContexts; public Input(BytesRef term, long v, BytesRef payload) { - this(term, v, payload, true); + this(term, v, payload, true, null, false); } public Input(String term, long v, BytesRef payload) { - this(new BytesRef(term), v, payload, true); + this(new BytesRef(term), v, payload); + } + + public Input(BytesRef term, long v, Set contexts) { + this(term, v, null, false, contexts, true); + } + + public Input(String term, long v, Set contexts) { + this(new BytesRef(term), v, null, false, contexts, true); } public Input(BytesRef term, long v) { - this(term, v, null, false); + this(term, v, null, false, null, false); } public Input(String term, long v) { - this(new BytesRef(term), v, null, false); + this(new BytesRef(term), v, null, false, null, false); } - public Input(BytesRef term, long v, BytesRef payload, boolean hasPayloads) { + public Input(String term, int v, BytesRef payload, Set contexts) { + this(new BytesRef(term), v, payload, true, contexts, true); + } + + public Input(BytesRef term, long v, BytesRef payload, Set contexts) { + this(term, v, payload, true, contexts, true); + } + + + + public Input(BytesRef term, long v, BytesRef payload, boolean hasPayloads, Set contexts, + boolean hasContexts) { this.term = term; this.v = v; this.payload = payload; this.hasPayloads = hasPayloads; + this.contexts = contexts; + this.hasContexts = hasContexts; + } + + public boolean hasContexts() { + return hasContexts; } public boolean hasPayloads() { diff --git lucene/suggest/src/test/org/apache/lucene/search/suggest/InputArrayIterator.java lucene/suggest/src/test/org/apache/lucene/search/suggest/InputArrayIterator.java index edebb37..75301e9 100644 --- lucene/suggest/src/test/org/apache/lucene/search/suggest/InputArrayIterator.java +++ lucene/suggest/src/test/org/apache/lucene/search/suggest/InputArrayIterator.java @@ -19,6 +19,7 @@ package org.apache.lucene.search.suggest; import java.util.Arrays; import java.util.Iterator; +import java.util.Set; import org.apache.lucene.util.BytesRef; @@ -28,6 +29,7 @@ import org.apache.lucene.util.BytesRef; public final class InputArrayIterator implements InputIterator { private final Iterator i; private final boolean hasPayloads; + private final boolean hasContexts; private boolean first; private Input current; private final BytesRef spare = new BytesRef(); @@ -38,8 +40,10 @@ public final class InputArrayIterator implements InputIterator { current = i.next(); first = true; this.hasPayloads = current.hasPayloads; + this.hasContexts = current.hasContexts; } else { this.hasPayloads = false; + this.hasContexts = false; } } @@ -78,4 +82,14 @@ public final class InputArrayIterator implements InputIterator { public boolean hasPayloads() { return hasPayloads; } + + @Override + public Set contexts() { + return current.contexts; + } + + @Override + public boolean hasContexts() { + return hasContexts; + } } \ No newline at end of file diff --git lucene/suggest/src/test/org/apache/lucene/search/suggest/analyzing/ContextAwareAnalyzingSuggesterTest.java lucene/suggest/src/test/org/apache/lucene/search/suggest/analyzing/ContextAwareAnalyzingSuggesterTest.java new file mode 100644 index 0000000..bb5e7aa --- /dev/null +++ lucene/suggest/src/test/org/apache/lucene/search/suggest/analyzing/ContextAwareAnalyzingSuggesterTest.java @@ -0,0 +1,170 @@ +package org.apache.lucene.search.suggest.analyzing; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.analysis.MockTokenizer; +import org.apache.lucene.search.suggest.Lookup.LookupResult; +import org.apache.lucene.search.suggest.Input; +import org.apache.lucene.search.suggest.InputArrayIterator; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.LuceneTestCase; +import org.apache.lucene.util._TestUtil; + +public class ContextAwareAnalyzingSuggesterTest extends LuceneTestCase { + + public void testWithPayload() throws Exception { + Iterable keysWithPayload = shuffle( + new Input("pfoo", 50, new BytesRef("payload1"), new HashSet(Arrays.asList(new BytesRef("blah"), new BytesRef("blah1")))), + new Input("pfao", 50, new BytesRef("payload22"), new HashSet(Arrays.asList(new BytesRef("blah")))), + new Input("pflo", 50, new BytesRef("payload333"), new HashSet(Arrays.asList(new BytesRef("context")))) + ); + + ContextAwareAnalyzingSuggester suggester = new ContextAwareAnalyzingSuggester(new MockAnalyzer(random(), MockTokenizer.KEYWORD, false)); + suggester.build(new InputArrayIterator(keysWithPayload)); + + List results = suggester.lookup(_TestUtil.stringToCharSequence("pf", random()), new HashSet(Arrays.asList("blah1")), 2); + assertEquals(1, results.size()); + assertEquals("pfoo", results.get(0).key.toString()); + assertContexts(new String[]{"blah", "blah1"}, results.get(0).contexts); + assertEquals("payload1", results.get(0).payload.utf8ToString()); + assertEquals(50, results.get(0).value, 0.01F); + } + + public void testOnlyContext() throws Exception { + Iterable keys = shuffle( + new Input("pfoo", 50, new HashSet(Arrays.asList(new BytesRef("blah"), new BytesRef("blah1")))), + new Input("foo", 50, new HashSet(Arrays.asList(new BytesRef("blah"), new BytesRef("blah1")))), + new Input("flu", 10, new HashSet(Arrays.asList(new BytesRef("context")))) + ); + + ContextAwareAnalyzingSuggester suggester = new ContextAwareAnalyzingSuggester(new MockAnalyzer(random(), MockTokenizer.KEYWORD, false)); + suggester.build(new InputArrayIterator(keys)); + + List results = suggester.lookup(_TestUtil.stringToCharSequence("f", random()), new HashSet(Arrays.asList("blah")), 2); + assertEquals(1, results.size()); + assertEquals("foo", results.get(0).key.toString()); + assertContexts(new String[]{"blah", "blah1"}, results.get(0).contexts); + assertEquals(50, results.get(0).value, 0.01F); + + results = suggester.lookup(_TestUtil.stringToCharSequence("f", random()), new HashSet(Arrays.asList("blah1")), 2); + assertEquals(1, results.size()); + assertEquals("foo", results.get(0).key.toString()); + assertContexts(new String[]{"blah", "blah1"}, results.get(0).contexts); + assertEquals(50, results.get(0).value, 0.01F); + + // lookup key was a context + results = suggester.lookup(_TestUtil.stringToCharSequence("contex", random()), new HashSet(Arrays.asList("")), 2); + assertEquals(0, results.size()); + + // context is a subset of another context + results = suggester.lookup(_TestUtil.stringToCharSequence("1", random()), new HashSet(Arrays.asList("blah")), 2); + assertEquals(0, results.size()); + + } + + public void testMultiResults() throws Exception { + Iterable keys = shuffle( + new Input("foo", 50, new HashSet(Arrays.asList(new BytesRef("blah"), new BytesRef("blah1")))), + new Input("foa", 10, new HashSet(Arrays.asList(new BytesRef("blah3"), new BytesRef("blah")))), + new Input("bar", 5, new HashSet(Arrays.asList(new BytesRef("context")))) + ); + + ContextAwareAnalyzingSuggester suggester = new ContextAwareAnalyzingSuggester(new MockAnalyzer(random(), MockTokenizer.KEYWORD, false)); + suggester.build(new InputArrayIterator(keys)); + + List results = suggester.lookup(_TestUtil.stringToCharSequence("fo", random()), new HashSet(Arrays.asList("blah")), 2); + assertEquals(2, results.size()); + assertEquals("foo", results.get(0).key.toString()); + assertEquals("foa", results.get(1).key.toString()); + assertEquals(50, results.get(0).value, 0.01F); + assertEquals(10, results.get(1).value, 0.01F); + assertContexts(new String[] {"blah", "blah1"}, results.get(0).contexts); + assertContexts(new String[] {"blah", "blah3"}, results.get(1).contexts); + } + + public void testMultiContexts() throws Exception { + Iterable keys = shuffle( + new Input("foo", 50, new HashSet(Arrays.asList(new BytesRef("Greece"), new BytesRef("Germany")))), + new Input("foa", 10, new HashSet(Arrays.asList(new BytesRef("Bangladesh"), new BytesRef("Brazil")))), + new Input("bar", 5, new HashSet(Arrays.asList(new BytesRef("context")))) + ); + + ContextAwareAnalyzingSuggester suggester = new ContextAwareAnalyzingSuggester(new MockAnalyzer(random(), MockTokenizer.KEYWORD, false)); + suggester.build(new InputArrayIterator(keys)); + + List results = suggester.lookup(_TestUtil.stringToCharSequence("fo", random()), new HashSet(Arrays.asList("Bangladesh", "Greece")), 2); + assertEquals(2, results.size()); + assertEquals("foo", results.get(0).key.toString()); + assertEquals("foa", results.get(1).key.toString()); + assertEquals(50, results.get(0).value, 0.01F); + assertEquals(10, results.get(1).value, 0.01F); + assertContexts(new String[] {"Greece", "Germany"}, results.get(0).contexts); + assertContexts(new String[] {"Bangladesh", "Brazil"}, results.get(1).contexts); + } + + public void testOverlappingContexts() throws Exception { + Iterable keys = shuffle( + new Input("foo", 50, new HashSet(Arrays.asList(new BytesRef("Greece"), new BytesRef("Germany")))), + new Input("foa", 10, new HashSet(Arrays.asList(new BytesRef("Bangladesh"), new BytesRef("Germany")))), + new Input("bar", 5, new HashSet(Arrays.asList(new BytesRef("context")))) + ); + + ContextAwareAnalyzingSuggester suggester = new ContextAwareAnalyzingSuggester(new MockAnalyzer(random(), MockTokenizer.KEYWORD, false)); + suggester.build(new InputArrayIterator(keys)); + + List results = suggester.lookup(_TestUtil.stringToCharSequence("fo", random()), new HashSet(Arrays.asList("Germany", "Greece")), 2); + + assertEquals(2, results.size()); + assertEquals("foo", results.get(0).key.toString()); + assertEquals("foa", results.get(1).key.toString()); + assertEquals(50, results.get(0).value, 0.01F); + assertEquals(10, results.get(1).value, 0.01F); + assertContexts(new String[] {"Greece", "Germany"}, results.get(0).contexts); + assertContexts(new String[] {"Bangladesh", "Germany"}, results.get(1).contexts); + } + + private void assertContexts(String[] expectedContexts, Set contexts) { + Set actualContexts = new HashSet<>(); + assertTrue(contexts.size() == expectedContexts.length); + for (BytesRef ctx : contexts) { + actualContexts.add(ctx.utf8ToString()); + } + for (String expectedContext : expectedContexts) { + assertTrue(actualContexts.contains(expectedContext)); + } + } + + @SafeVarargs + public final Iterable shuffle(T...values) { + final List asList = new ArrayList(values.length); + for (T value : values) { + asList.add(value); + } + Collections.shuffle(asList, random()); + return asList; + } + +} diff --git lucene/suggest/src/test/org/apache/lucene/search/suggest/analyzing/TestFreeTextSuggester.java lucene/suggest/src/test/org/apache/lucene/search/suggest/analyzing/TestFreeTextSuggester.java index 7d3e3cc..8338093 100644 --- lucene/suggest/src/test/org/apache/lucene/search/suggest/analyzing/TestFreeTextSuggester.java +++ lucene/suggest/src/test/org/apache/lucene/search/suggest/analyzing/TestFreeTextSuggester.java @@ -171,6 +171,16 @@ public class TestFreeTextSuggester extends LuceneTestCase { public boolean hasPayloads() { return false; } + + @Override + public Set contexts() { + return null; + } + + @Override + public boolean hasContexts() { + return false; + } }); if (VERBOSE) { System.out.println(sug.sizeInBytes() + " bytes"); @@ -362,6 +372,16 @@ public class TestFreeTextSuggester extends LuceneTestCase { public boolean hasPayloads() { return false; } + + @Override + public Set contexts() { + return null; + } + + @Override + public boolean hasContexts() { + return false; + } }); // Build inefficient but hopefully correct model: