Index: gen.py =================================================================== --- gen.py (revision 0) +++ gen.py (revision 0) @@ -0,0 +1,1286 @@ +import sys +import os + +# TODO +# - make custom BQ rewrite logic to "collapse" corner cases to their simpler eq's +# - call query.rewrite first in fast search +# - try tricks eg make "final" local var copy of class static var +# - conditionalize Bucket class vs parallel arrays +# - don't do fast reject for TermQuery... or do it, but then don't check "competes" again +# - hurm: do the "random access JumpScorer" optimization +# - make a simple "gen" mode that has a "learning" step where you +# send many queries through and it simply gathers data on which +# specs are needed, followed by building those specs +# - more strongly model "I am top collector loop" vs "I am secondary" +# - split out query gen code that's "in the top loop" vs "a +# sub-clause", ie, can do compete/score/collect itself or must +# strongly separate doc vs score +# - allow per-segment specializing +# - eg deletes/no +# - clean up this source code +# - make verify.cmd faster -- single benchmark run? 
+# - run all tests, routing queries through spec +# - make this more "incremental" eg if I can't specialize the scorer but I can specialize the collector, do so +# - eg handle custom FieldComparator (must call setScorer now), custom Collector +# - provide a "my field has valid sentinel values" to force docID tie breakers +# - maybe do this as part of a "single warmup" step that does +# things like check for no-nulls in FieldCache for string, no +# max/min val in numeric fields, etc +# - "sort by docid", "sort by score" +# - omitTF +# - make "valueType" a comparator +# - OR-query +# - gen minNrShouldMatch variants +# - relax to N clauses +# - allow prohibited, must terms +# - implement BooleanScorer2 (OR/AND scorers) +# - don't make double-while loop when doing next() +# - re-instate "i am top query" optimization, eg for termquery +# - collapse the while loops +# - RelevanceComparator is broken +# - we can skip freq, instead of reading into freq var, if hit does not compete +# - when score is "no" we should have single "skip freq vint" read +# - maybe use a custom "skipVInt" method -- reads the bytes w/o building int +# - FIX score: don't compute it if we are not tracking maxScore +# - support sparse filters (as iterator, boolean query clause?) +# - reversed sort +# - multifield sort +# - merge in all stop criteria, eg ++count > aLimit +# - for scoredoc we don't need to re-add docBase -- just init doc=0 +# - hmm: will we "flood" the instruction cache when many threads running different specialized code? +# - share norms per field (if multiple term queries run against same field) +# - for real integration +# - need to "know" that filter will give us an openbitset, somehow? 
+# - tweak order of ifs +# - TODO: figure out if side-by-side arrays are slower, or +# pointer-to-obj is slower, for the queue +# - score doc & fields +# - downHeap could be optimized -- that initial "if k < 10" +# - for TopScoreDoc collection, try to re-coalesce multiple if statemetns (it got slower!!) + +END_DOC = sys.maxint + +class Writer: + + def __init__(self): + self.indent = 0 + self.l = [] + self.vars = {} + self.upto = 0 + + def __call__(self, s): + if s.find('}') != -1: + self.indent -= 1 + c = 0 + while c < len(s)-1 and s[c] == ' ': + c += 1 + self.l.append(' '*(2*self.indent-c) + s) + if s.find('{') != -1: + self.indent += 1 + + def __str__(self): + return '\n'.join(self.l) + + def getVar(self, prefix): + upto = 0 + while True: + v = '%s%s' % (prefix, upto) + if v not in self.vars: + self.vars[v] = True + return v + upto += 1 + + def getID(self): + s ='q%d' % self.upto + self.upto += 1 + return s + + def releaseVar(self, v): + del self.vars[v] + +class BasePQCollector: + + valueType = None + docsAlwaysInOrder = True + + def __init__(self, w): + self.w = w + + def setState(self, s): + self.s = s + + def upperValueType(self, isFieldCache=False): + if self.valueType == 'int': + if isFieldCache: + return 'Int' + else: + return 'Integer' + else: + return self.valueType[0].upper() + self.valueType[1:] + + def readerInit(self): + pass + + def writeTempBottom(self): + pass + + def writeTopClass(self): + pass + + def topInit(self): + if self.s.doTotalHits: + self.w('int hitCount = 0;') + + def collectOne(self, q, checkEndDoc=True, docsInOrder=True): + w = self.w + if checkEndDoc: + w('if (doc == %s) {' % END_DOC) + w(' break;') + w('}') + if self.s.doTotalHits: + w('hitCount++;') + + if DEBUG: + w('System.out.println("doc=" + (doc+docBase) + " score=" + %s);' % q.scoreVar) + + if not docsInOrder: + w(' final int fullDoc = doc + docBase;') + + #w('System.out.println("rx=" + rx + " doc=" + (doc+docBase) + " ord=" + order[doc] + " val=" + lookup[order[doc]] + " 
bottom.ord=" + bottom.ord + " bottom.val=" + bottom.val);') + self.writeCompetes(q, docsInOrder) + + w('') + + if docsInOrder: + w(' final int fullDoc = doc + docBase;') + + self.copyCurrentToBottom() + + w('') + w(' // downheap') + + iVar = w.getVar('i') + w(' int %s = 1;' % iVar) + + jVar = w.getVar('j') + w(' int %s = %s << 1;' % (jVar, iVar)) + + kVar = w.getVar('k') + w(' int %s = %s+1;' % (kVar, jVar)) + w(' if (%s <= topN) {' % kVar) + self.lessThan(kVar, jVar, 'lt') + #w(' System.out.println("init lt=" + lt);') + w(' if (lt) {') + w(' %s = %s;' % (jVar, kVar)) + w(' }') + w(' }') + w(' while(%s <= topN) {' % jVar) + self.currentGreaterThanBreak(jVar, docsInOrder) + self.copy(jVar, iVar) + w(' %s = %s;' % (iVar, jVar)) + w(' %s = %s << 1;' % (jVar, iVar)) + w(' %s = %s+1;' % (kVar, jVar)) + w(' if (%s <= topN) {' % kVar) + self.lessThan(kVar, jVar, 'lt') + w(' if (lt) {') + w(' %s = %s;' % (jVar, kVar)) + w(' }') + w(' }') + w(' }') + w.releaseVar(kVar) + w.releaseVar(jVar) + + self.installBottom(iVar) + w.releaseVar(iVar) + + self.endInsert() + + # matches writeCompetes + w('}') + + def createResults(self, v, c): + w = self.w + + w('') + if self.__class__ == ScoreDocCollector: + self.w('final float maxScore = queueScores[1];') + + w('// Build results -- sort pqueue entries') + w('final SorterTemplate sorter = new SorterTemplate() {') + w(' protected int compare(int i, int j) {') + self.lessThan('i', 'j', 'lt') + w(' if (lt) {') + w(' return 1;') + w(' } else {') + w(' // pq entries are never equal') + w(' return -1;') + w(' }') + w(' }') + w(' protected void swap(int i, int j) {') + self.swap('i', 'j') + w(' }') + w('};') + + w('// Extract results') + w('sorter.quickSort(1, topN);') + w('ScoreDoc[] hits = new ScoreDoc[topN];') + w('for(int i=0;i= score) {' % v) + else: + self.w('if (queueScores[%s] > score || (queueScores[%s] == score && queueDocs[%s] < fullDoc)) {' % (v, v, v)) + self.w(' break;') + self.w('}') + + def lessThan(self, a, b, v): + 
self.w('final boolean %s = queueScores[%s] < queueScores[%s] || (queueScores[%s] == queueScores[%s] && queueDocs[%s] > queueDocs[%s]);' % (v, a, b, a, b, a, b)) + + def copy(self, src, dest): + self.w('queueDocs[%s] = queueDocs[%s];' % (dest, src)) + self.w('queueScores[%s] = queueScores[%s];' % (dest, src)) + + def installBottom(self, dest): + self.w('queueDocs[%s] = fullDoc;' % dest) + self.w('queueScores[%s] = score;' % dest) + + def copyCurrentToBottom(self): + if DEBUG: + self.w('System.out.println(" boot doc=" + queueDocs[1] + " for new doc=" + fullDoc + " (eq?=" + (score == bottomScore) + ")" + " score=" + score + " bottomScore=" + bottomScore);') + self.w('queueDocs[1] = fullDoc;') + self.w('queueScores[1] = score;') + + def endInsert(self): + self.w('bottomScore = queueScores[1];') + if not self.docsAlwaysInOrder: + self.w('bottomDoc = queueDocs[1];') + if DEBUG: + self.w('System.out.println(" bottom=" + bottomScore + " doc=" + queueDocs[1]);') + + def writeCompetes(self, q, docsInOrder=True): + q.writeScore() + if docsInOrder: + self.w('if (score > bottomScore) {') + else: + self.w('if (score > bottomScore || (score == bottomScore && fullDoc < bottomDoc)) {') + + def swap(self, i, j): + w = self.w + w('final int itmp = queueDocs[%s];' % i) + w('queueDocs[%s] = queueDocs[%s];' % (i, j)) + w('queueDocs[%s] = itmp;' % j) + w('final float ftmp = queueScores[%s];' % i) + w('queueScores[%s] = queueScores[%s];' % (i, j)) + w('queueScores[%s] = ftmp;' % j) + + def createTopDocs(self): + if self.s.doTotalHits: + v = 'hitCount' + else: + v = '-1' + self.w('final TopDocs results = new TopDocs(%s, hits, queueScores[1]);' % v) + + +class StringOrdValComparator: + + def __init__(self, w, replaceNulls): + self.w = w + self.replaceNulls = replaceNulls + + def getRejectDoc(self): + return 'order[__DOC__] > bottom.ord' + + def setFieldValue(self, src): + if self.replaceNulls: + self.w('if (queue[1+%s].val.equals("\u0000")) {' % src) + self.w(' fields[0] = null;') + 
self.w('} else {') + self.w('fields[0] = queue[1+%s].val;' % src) + if self.replaceNulls: + self.w('}') + + def writeCompetesDef(self): + w = self.w + w('private final int competes(final Entry bottom, final int ord, final String val) {') + w(' // ords are comparable') + #w(' System.out.println("competes bottom.ord=" + bottom.ord + " vs docOrd=" + ord + " val=" + val);') + w(' final int c = bottom.ord - ord;') + w(' if (c != 0) {') + w(' return c;') + w(' } else {') + if self.replaceNulls: + w(' return bottom.val.compareTo(val);') + else: + w(' if (bottom.val == null) {') + w(' if (val == null) {') + w(' return 0;') + w(' } else {') + w(' return -1;') + w(' }') + w(' } else if (val == null) {') + w(' return 1;') + w(' } else {') + w(' return bottom.val.compareTo(val);') + w(' }') + w(' }') + w('}') + + def writeCompetes(self, q, docsInOrder): + w = self.w + #w('System.out.println("collect doc=" + doc);') + if docsInOrder: + w('if (competes(bottom, order[doc], lookup[order[doc]]) > 0) {') + #w('System.out.println(" competes! 
doc=" + doc + " val=" + lookup[order[doc]] + " bottom.ord=" + bottom.ord + " bottom.val=" + bottom.val);') + else: + w('final int cx = competes(bottom, order[doc], lookup[order[doc]]);') + w('if (cx > 0 || (cx == 0 && fullDoc < bottom.docID)) {') + + def currentGreaterThanBreak(self, v, docsInOrder=True): + w = self.w + w('final int c = compareOrdVal(queue[%s], bottom);' % v) + if docsInOrder: + w('if (c <= 0) {') + #w(' System.out.println(" now break k0=" + k0 + " j0=" + j0 + " c=" + c);') + else: + w('if (c < 0 || (c == 0 && queue[%s].docID < bottom.docID)) {' % v) + w(' break;') + w('}') + + def copyCurrentToBottom(self): + self.w('bottom.ord = order[doc];') + self.w('bottom.val = lookup[bottom.ord];') + self.w('bottom.readerGen = rx;') + + def initSingleEntryQueue(self): + self.w('e.ord = Integer.MAX_VALUE;') + self.w('e.val = sentinel;') + + def writeTop(self): + self.w('// U+FFFF is not a valid unicode char, so this is a safe sentinel') + self.w('final String sentinel = new String("\uFFFF");') + + def endInsert(self): + self.writeConvertBottom() + + def writeCompare(self): + w = self.w + w('private final int compareOrdVal(Entry e1, Entry e2) {') + #w(' System.out.println("compareOrdVal e1.readerGen=" + e1.readerGen + " e2.readerGen=" + e2.readerGen);') + #w(' System.out.println("compareOrdVal e1.val=" + e1.val + " e2.val=" + e2.val);') + w(' if (e1.readerGen == e2.readerGen) {') + w(' // ords are comparable') + w(' final int c = e1.ord - e2.ord;') + #w(' System.out.println(" c=" + c);') + w(' if (c != 0) {') + w(' return c;') + w(' }') + w(' }') + if not self.replaceNulls: + w(' if (e1.val == null) {') + w(' if (e2.val == null) {') + w(' return 0;') + w(' } else {') + w(' return -1;') + w(' }') + w(' } else if (e2.val == null) {') + w(' return 1;') + w(' }') + w(' return e1.val.compareTo(e2.val);') + w('}') + + def writeBinarySearch(self): + self.w(''' final private static int binarySearch(String[] a, String key) { + return binarySearch(a, key, 0, 
a.length-1); + }''') + + self.w(''' final private static int binarySearch(String[] a, String key, int low, int high) { + + while (low <= high) { + int mid = (low + high) >>> 1; + String midVal = a[mid]; + int cmp;''') + + if self.replaceNulls: + self.w('cmp = midVal.compareTo(key);') + else: + self.w(''' + if (midVal != null) { + cmp = midVal.compareTo(key); + } else { + cmp = -1; + }''') + self.w(''' + if (cmp < 0) + low = mid + 1; + else if (cmp > 0) + high = mid - 1; + else + return mid; + } + return -(low + 1); + }''') + + def writeTopClass(self): + self.writeCompare() + self.writeCompetesDef() + self.writeBinarySearch() + if self.replaceNulls: + self.w('final HashSet fixedNulls = new HashSet();') + + def writeEntryClass(self): + self.w('int ord;') + self.w('String val;') + self.w('int readerGen;') + + def writePerReader(self): + self.w('StringIndex sx = ExtendedFieldCache.EXT_DEFAULT.getStringIndex(r, sortFieldID);') + self.w('final String[] lookup = sx.lookup;') + if self.replaceNulls: + self.w(''' + // nocommit -- not clean, not threadsafe, not right (needs to be tuple of reader & field) + if (!fixedNulls.contains(r)) { + fixedNulls.add(r); + final int size = lookup.length; + for(int sxi=0;sxi queue[%s].docID;' % (v, a, b)) + w('} else {') + if self.comp is not None or self.valueType is None: + w(' %s = %s > 0;' % (v, c)) + else: + w(' %s = queue[%s].value > queue[%s].value;' % (v, a, b)) + w('}') + w.releaseVar(c) + + def copy(self, src, dest): + self.w('queue[%s] = queue[%s];' % (dest, src)) + + def installBottom(self, dest): + self.w('queue[%s] = bottom;' % dest) + + def copyCurrentToBottom(self): + self.w('bottom.docID = fullDoc;') + if self.doTrackScores: + self.w('bottom.score = score;') + if self.comp is not None: + self.comp.copyCurrentToBottom() + elif self.valueType is not None: + self.w('bottomValue = bottom.value = docValues[doc];') + else: + if self.doTrackScores: + s = 'score' + else: + s = 'Float.NaN' + self.w('comp.copy(bottom.slot, doc, 
%s);' % s) + + def endInsert(self): + self.w('bottom = queue[1];') + if self.comp is not None: + self.comp.endInsert() + elif self.valueType is not None: + self.w('bottomValue = bottom.value;') + else: + self.w('comp.setBottom(bottom.slot);') + + def writeCompetes(self, q, docsInOrder=True): + w = self.w + + if self.doMaxScore: + q.writeScore() + w(' if (score > maxScore) {') + w(' maxScore = score;') + w(' }') + + if self.comp is not None: + self.comp.writeCompetes(q, docsInOrder) + elif self.valueType is None: + w(' final int cmp = comp.compareBottom(doc, Float.NaN);') + w(' if (cmp > 0) {') + elif docsInOrder: + w(' if (docValues[doc] < bottomValue) {') + else: + w(' if (docValues[doc] < bottomValue || (docValues[doc] == bottomValue && fullDoc < bottom.docID)) {') + + if self.doTrackScores: + q.writeScore() + + if DEBUG: + w(' System.out.println(" does compete bottom");') + + def swap(self, i, j): + w = self.w + w('final Entry tmp = queue[%s];' % i) + w('queue[%s] = queue[%s];' % (i, j)) + w('queue[%s] = tmp;' % j) + + def createTopDocs(self): + if self.doMaxScore: + x = 'maxScore' + else: + x = 'Float.NaN' + if self.s.doTotalHits: + v = 'hitCount' + else: + v = '-1' + self.w('final TopDocs results = new TopFieldDocs(%s, hits, sort.getSort(), %s);' % (v, x)) + +class Query: + + def var(self, name): + return '%s%s%s' % (self.id, name[0].upper(), name[1:]) + + def pushState(self, s): + self.s = s + for query in self.getSubQueries(): + query.pushState(s) + + def setVars(self, docVar, scoreVar): + self.docVar = docVar + self.scoreVar = scoreVar + for query in self.getSubQueries(): + query.setVars(query.var('doc'), + query.var('score')) + +class NClauseOrQuery(Query): + + def __init__(self, shouldClauses, mustNotClauses, minNR): + self.shouldQueries = shouldClauses + self.mustNotQueries = mustNotClauses + self.minNR = minNR + + def getSubQueries(self): + return self.shouldQueries + self.mustNotQueries + + def writeTopClass(self, c): + w = self.w + self.c = c + 
c.docsAlwaysInOrder = False + if not self.s.doBucketArray: + w('private static final class Bucket {') + w(' int doc = -1;') + if c.doMaxScore: + w(' float score;') + elif c.doTrackScores: + for q in self.shouldQueries: + q.writeEntryClassData() + + if c.doMaxScore or c.doTrackScores or self.minNR != 0: + w(' int coord;') + if len(self.mustNotQueries) > 0: + w(' boolean reject;') + #w(' int bits;') + w('}') + w('private final static Bucket[] buckets = new Bucket[%d];' % self.CHUNK) + w('private final static int[] usedBuckets = new int[%d];' % self.CHUNK) + w('static {') + w(' for(int i=0;i<%d;i++) {' % self.CHUNK) + w(' buckets[i] = new Bucket();') + w(' }') + w('}') + else: + w('private static final int[] bucketDocs = new int[%d];' % self.CHUNK) + w('private static final int[] usedBuckets = new int[%d];' % self.CHUNK) + if c.doMaxScore: + w('private static final float[] bucketScores = new float[%d];' % self.CHUNK) + elif c.doTrackScores: + for q in self.shouldQueries: + q.writeEntryClassData(self.CHUNK) + if c.doMaxScore or c.doTrackScores or self.minNR != 0: + w('private static final int[] bucketCoords = new int[%d];' % self.CHUNK) + if len(self.mustNotQueries) > 0: + w('private static final boolean[] bucketReject = new boolean[%d];' % self.CHUNK) + w('static {') + w(' Arrays.fill(bucketDocs, -1);') + w('}') + + def setWriter(self, w): + self.w = w + self.id = w.getID() + for q in self.shouldQueries+self.mustNotQueries: + q.setWriter(w) + + didScore = False + def writeScore(self): + if self.didScore: + return + self.didScore = True + w = self.w + c = self.c + if c.doTrackScores or c.doMaxScore: + if not c.doMaxScore: + # score computation deferred + self.w('float score = 0f;') + l = self.shouldQueries[:] + l.reverse() + for q in l: + q.writeScore() + if not self.s.doBucketArray: + w(' score *= coordFactors[b.coord];') + else: + w(' score *= coordFactors[bucketCoords[spot]];') + else: + if not self.s.doBucketArray: + w(' final float score = b.score * 
coordFactors[b.coord];') + else: + w(' final float score = bucketScores[spot] * coordFactors[bucketCoords[spot]];') + + def writeTop(self, c, qVar, scores=True): + w = self.w + w('final BooleanClause[] clauses = ((BooleanQuery) q).getClauses();') + for i in range(len(self.shouldQueries)+len(self.mustNotQueries)): + w('final Query q%d = clauses[%d].getQuery();' % (1+i, i)) + if c.doTrackScores or c.doMaxScore: + for i, q in enumerate(self.shouldQueries): + w('final Weight %s = ((BooleanQuery.BooleanWeight) %s).subWeight(%d);' % \ + (q.var('weight'), self.var('weight'), i)) + for i in range(len(self.shouldQueries)): + self.shouldQueries[i].writeTop(c, 'q%d' % (1+i)) + for i in range(len(self.mustNotQueries)): + self.mustNotQueries[i].writeTop(c, 'q%d' % (1+i+len(self.shouldQueries)), scores=False) + if c.doTrackScores or c.doMaxScore: + w('final float coordFactors[] = new float[%d];' % (len(self.shouldQueries)+1)) + for i in range(len(self.shouldQueries)+1): + w('coordFactors[%d] = sim.coord(%d, %d);' % (i, i, len(self.shouldQueries))) + + def writePerReader(self, c): + for q in self.shouldQueries: + q.writePerReader(c) + for q in self.mustNotQueries: + q.writePerReader(c, scores=False) + self.w('int limit = 0;') + for q in self.shouldQueries: + q.next() + for q in self.mustNotQueries: + q.next(scores=False) + + CHUNK = 512 + + def writeOneSubCollect(self, q): + w = self.w + mustNot = q in self.mustNotQueries + + w('while(%s < limit) {' % q.docVar) + w(' final int spot = %s&%s;' % (q.docVar, self.CHUNK-1)) + + if self.c.doMaxScore and not mustNot: + # TODO: pass in b.score as scorevar + q.writeScore() + if not self.s.doBucketArray: + w(' final Bucket b = buckets[spot];') + w(' if (b.doc != %s) {' % q.docVar) + w(' b.doc = %s;' % q.docVar) + if mustNot: + w(' b.reject = true;') + elif len(self.mustNotQueries) > 0: + w(' b.reject = false;') + else: + w(' if (bucketDocs[spot] != %s) {' % q.docVar) + w(' bucketDocs[spot] = %s;' % q.docVar) + if mustNot: + w(' 
bucketReject[spot] = true;') + elif len(self.mustNotQueries) > 0: + w(' bucketReject[spot] = false;') + + w(' usedBuckets[usedCount++] = spot;') + if not mustNot: + if self.c.doMaxScore: + if not self.s.doBucketArray: + w(' b.score = %s;' % q.scoreVar) + else: + w(' bucketScores[spot] = %s;' % q.scoreVar) + elif self.c.doTrackScores: + for qx in self.shouldQueries: + if q == qx: + qx.writeSaveScoreData('b') + else: + qx.resetScoreData('b') + if self.c.doTrackScores or self.c.doMaxScore or self.minNR != 0: + if not self.s.doBucketArray: + w(' b.coord = 1;') + else: + w(' bucketCoords[spot] = 1;') + if self.c.doTrackScores or self.c.doMaxScore or self.minNR != 0: + if len(self.mustNotQueries) > 0 and not mustNot: + if not self.s.doBucketArray: + w(' } else if (!b.reject) {') + else: + w(' } else if (!bucketReject[spot]) {') + else: + w(' } else {') + if mustNot: + if not self.s.doBucketArray: + w(' b.reject = true;') + else: + w(' bucketReject[spot] = true;') + else: + if not self.s.doBucketArray: + w(' b.coord++;') + else: + w(' bucketCoords[spot]++;') + if self.c.doMaxScore: + if not self.s.doBucketArray: + w(' b.score += %s;' % q.scoreVar) + else: + w(' bucketScores[spot] += %s;' % q.scoreVar) + elif self.c.doTrackScores: + q.writeSaveScoreData('b') + elif mustNot: + w(' } else {') + if not self.s.doBucketArray: + w(' b.reject = true;') + else: + w(' bucketReject[spot] = true;') + w(' }') + q.next(scores=not mustNot) + w('}') + + def writeTopIter(self, c): + w = self.w + w('while(limit < maxDoc) { // loop for all hits') + w(' limit += %d;' % self.CHUNK) + w(' if (limit > maxDoc) {') + w(' limit = maxDoc;') + w(' }') + w('int usedCount=0;') + l = list(enumerate(self.shouldQueries)) + l.reverse() + for i, q in l: + w('// loop for q%d chunk' % (1+i)) + self.writeOneSubCollect(q) + for i, q in enumerate(self.mustNotQueries): + w('// loop for q%d chunk' % (1+i+len(self.shouldQueries))) + self.writeOneSubCollect(q) + w(' // now collect') + w(' while(usedCount-- != 0) 
{') + if not self.s.doBucketArray: + w(' final Bucket b = buckets[usedBuckets[usedCount]];') + l = [] + if len(self.mustNotQueries) > 0: + l.append('b.reject') + if self.minNR != 0: + l.append('b.coord < %d' % self.minNR) + if len(l) > 0: + w(' if (%s) {' % (' || '.join(l))) + w(' continue;') + w(' }') + w(' final int doc = b.doc;') + else: + w(' final int spot = usedBuckets[usedCount];') + l = [] + if len(self.mustNotQueries) > 0: + l.append('bucketReject[spot]') + if self.minNR != 0: + l.append('bucketCoords[spot] < %d' % self.minNR) + if len(l) > 0: + w(' if (%s) {' % (' || '.join(l))) + w(' continue;') + w(' }') + w(' final int doc = bucketDocs[spot];') + c.collectOne(self, False, False) + w(' }') + w(' if (limit == r.maxDoc()) {') + w(' break;') + w(' }') + w('}') + +class TermQuery(Query): + + SCORE_CACHE_SIZE = 32 + BLOCK_SIZE = 32 + + didScore = False + + def setWriter(self, w): + self.w = w + self.id = w.getID() + + def getSubQueries(self): + return () + + def writeEntryClassData(self, asArray=0): + if asArray == 0: + self.w('int %s;' % self.var('freq')) + else: + self.w('final private static int[] bucket%s = new int[%d];' % (self.var('Freq'), asArray)) + + def writeTopClass(self, c): + pass + + def releaseVar(self, v): + self.w.releaseVar(v) + + didScoreData = None + def writeSaveScoreData(self, dest): + if not self.s.doBucketArray: + self.w('%s.%s = %s;' % (dest, self.var('freq'), self.var('freq'))) + else: + self.w('bucket%s[spot] = %s;' % (self.var('Freq'), self.var('freq'))) + self.didScoreData = 'yes' + + def resetScoreData(self, dest): + if not self.s.doBucketArray: + self.w('%s.%s = -1;' % (dest, self.var('freq'))) + else: + self.w('bucket%s[spot] = -1;' % self.var('Freq')) + + def writeScore(self): + w = self.w + if not self.didScore: + self.didScore = True + if self.didScoreData is not None: + freqVar = self.var('freqA') + if not self.s.doBucketArray: + w('if (b.%s != -1) {' % self.var('freq')) + w(' final int %s = b.%s;' % (self.var('freqA'), 
self.var('freq'))) + else: + w('final int %s = bucket%s[spot];' % (self.var('freqA'), self.var('Freq'))) + w('if (%s != -1) {' % self.var('freqA')) + docVar = 'doc' + else: + freqVar = self.var('freq') + docVar = self.docVar + + if self.s.doNorms: + w(' final float %s = (%s < %d ? %s[%s] : sim.tf(%s)*%s) * normDecoder[%s[%s] & 0xFF];' % \ + (self.scoreVar, + freqVar, + self.SCORE_CACHE_SIZE, + self.var('scoreCache'), + freqVar, + freqVar, + self.var('weightValue'), + self.var('norms'), + docVar)) + else: + w(' final float %s = (%s < %d ? %s[%s] : sim.tf(%s)*%s);' % \ + (self.scoreVar, + freqVar, + self.SCORE_CACHE_SIZE, + self.var('scoreCache'), + freqVar, + freqVar, + self.var('weightValue'))) + + if self.didScoreData is not None: + self.w(' score += %s;' % self.scoreVar) + self.w('}') + + def writeTop(self, c, qVar, scores=True): + w = self.w + w('final Term %s = ((TermQuery) %s).getTerm();' % (self.var('t'), qVar)) + + if scores and c.needsScores(): + w('final float[] %s = new float[%s];' % (self.var('scoreCache'), self.SCORE_CACHE_SIZE)) + w('final float %s = %s.getValue();' % \ + (self.var('weightValue'), self.var('weight'))) + w('for(int i=0;i<%s;i++) {' % self.SCORE_CACHE_SIZE) + w(' %s[i] = sim.tf(i) * %s;' % (self.var('scoreCache'), self.var('weightValue'))) + w('}') + + def writePerReader(self, c, scores=True): + w = self.w + if scores and c.needsScores() and self.s.doNorms: + w('final byte[] %s = r.norms(%s.field());' % (self.var('norms'), self.var('t'))) + w('int %s = 0;' % self.docVar) + if scores: + w('int %s = 0;' % self.var('freq')) + w('// only used to get the raw freqStream & limit') + w('final TermDocs %s = r.termDocs(%s);' % (self.var('TD'), self.var('t'))) + w('final IndexInput %s = ((SegmentTermDocs) %s).getFreqStream();' % (self.var('freqStream'), + self.var('TD'))) + w('final int %s = ((SegmentTermDocs ) %s).getTermFreq();' % (self.var('limit'), + self.var('TD'))) + w('int %s = 0;' % self.var('count')) + + def writeTopIter(q, c): + w = q.w + 
w('while(true) { // until we are done collecting hits from this reader') + q.next(True) + c.collectOne(q, checkEndDoc=False) + w('}') + + def next(self, isTop=False, scores=True): + w = self.w + if not isTop and self.s.rejectDoc is not None: + w('while(true) { // until we find a non-deleted & non-filtered-out doc') + w('if (++%s > %s) {' % (self.var('count'), self.var('limit'))) + if not isTop: + w(' %s = %s;' % (self.docVar, END_DOC)) + if self.s.rejectDoc is not None: + w(' break;') + else: + w(' break;') + w('} else {') + w(' final int %s = %s.readVInt();' % (self.var('x'), self.var('freqStream'))) + w(' %s += %s>>>1;' % (self.docVar, self.var('x'))) + if self.s.rejectDoc is not None: + w(' // fast reject') + w(' if (%s) {' % (self.s.rejectDoc.replace('__DOC__', self.docVar))) + w(' if ((%s & 1) == 0) {' % self.var('x')) + w(' // Skip freq') + w(' %s.readVInt();' % self.var('freqStream')) + w(' }') + w(' continue;') + w(' } else {') + if scores: + w(' if ((%s&1) != 0) {' % self.var('x')) + w(' %s = 1;' % self.var('freq')) + w(' } else {') + w(' %s = %s.readVInt();' % (self.var('freq'), self.var('freqStream'))) + w(' }') + else: + w(' if ((%s&1) == 0) {' % self.var('x')) + w(' // skip freq') + w(' %s.readVInt();' % self.var('freqStream')) + w(' }') + if not isTop: + w(' break;') + w(' }') + else: + if scores: + w(' if ((%s&1) != 0) {' % self.var('x')) + w(' %s = 1;' % self.var('freq')) + w(' } else {') + w(' %s = %s.readVInt();' % (self.var('freq'), self.var('freqStream'))) + w(' }') + else: + w(' if ((%s&1) == 0) {' % self.var('x')) + w(' // skip freq') + w(' %s.readVInt();' % self.var('freqStream')) + w(' }') + + w(' }') + if not isTop and self.s.rejectDoc is not None: + w(' }') + +DEBUG = '-verbose' in sys.argv + +def writeTopIter(q, c): + w = q.w + w('while(true) { // until we are done collecting hits from this reader') + q.next() + c.collectOne(q) + w('}') + +class State: + pass + +def gen(w, query, c, doFilter, doDeletes, doNorms, doTotalHits, fileName): + 
+ s = State() + s.doFilter = doFilter + s.doDeletes = doDeletes + if s.doFilter: + s.rejectDoc = '!filterBits.fastGet(__DOC__)' + elif s.doDeletes: + s.rejectDoc = 'deletedDocs != null && deletedDocs.get(__DOC__)' + else: + s.rejectDoc = None + s.doNorms = doNorms + s.doTotalHits = doTotalHits + s.doBucketArray = True + + query.pushState(s) + query.setWriter(w) + query.setVars('doc', 'score') + + c.setState(s) + + w('package org.apache.lucene.search;') + + w('import org.apache.lucene.util.*;') + w('import org.apache.lucene.store.*;') + w('import org.apache.lucene.search.FieldCache.StringIndex;') + w('import org.apache.lucene.index.*;') + w('import java.io.IOException;') + w('import java.util.Arrays;') + w('import java.util.HashSet;') + + className = os.path.splitext(os.path.split(fileName)[1])[0] + + w('final class %s extends SpecSearch {' % className) + w(' final private static SpecSearch instance = new %s();' % className); + w(' SpecSearch getInstance() {') + w(' return instance;') + w(' }') + + c.writeTopClass() + query.writeTopClass(c) + + w(' public TopDocs search(final IndexSearcher searcher, final Query q, final Filter filter, final Sort sort, final int topN) throws IOException {') + + c.topInit() + + w(' final float[] normDecoder = Similarity.getNormDecoder();') + w(' final Similarity sim = searcher.getSimilarity();') + w(' final IndexReader[] subReaders = searcher.getIndexReader().getSequentialSubReaders();') + w(' final int[] docBases = new int[subReaders.length];') + w(' {') + w(' int docBase = 0;') + w(' for(int rx=0;rx compiletop.log 2>&1') != 0: + raise RuntimeError('compile failed (see compiletop.log)') + +if 1: + print 'genall...' + import genall + genall.main() + +os.chdir('../benchmark') + +if 1: + if '-nocompile' not in sys.argv: + print 'compile...' 
+ if os.system('ant compile > compile.log 2>&1') != 0: + print open('compile.log').read() + raise RuntimeError('compile failed (see compile.log)') + +if windows: + RESULTS = 'results.win64' +else: + RESULTS = 'results' + +VERIFY = '-verify' in sys.argv + +ALG = ''' +analyzer=org.apache.lucene.analysis.standard.StandardAnalyzer +directory=FSDirectory +work.dir = $INDEX$ +search.num.hits = 10 +query.maker=org.apache.lucene.benchmark.byTask.feeds.FileBasedQueryMaker +file.query.maker.file = queries.txt +log.queries=true +filter.pct = $FILT$ +search.spec = $SEARCH_SPEC$ +total.hits = $TOTAL_HITS$ + +OpenReader +{"XSearchWarm" $SEARCH$} +$ROUNDS$ +CloseReader +RepSumByPrefRound XSearch +''' + +if os.path.exists('searchlogs'): + shutil.rmtree('searchlogs') + +os.makedirs('searchlogs') + +open('%s.txt' % RESULTS, 'wb').write('||Query||Sort||Filt|Deletes||Scoring||Hits||QPS (base)||QPS (new)||%||\n') + +numHit = 10 +counter = 0 + +if '-delindex' in sys.argv: + i = sys.argv.index('-delindex') + DEL_INDEX = sys.argv[1+i] + del sys.argv[i:i+2] + if False and not os.path.exists(DEL_INDEX): + raise RuntimeError('index "%s" does not exist' % DEL_INDEX) +else: + DEL_INDEX = None + +if '-nodelindex' in sys.argv: + i = sys.argv.index('-nodelindex') + NO_DEL_INDEX = sys.argv[1+i] + del sys.argv[i:i+2] + if False and not os.path.exists(NO_DEL_INDEX): + raise RuntimeError('index "%s" does not exist' % NO_DEL_INDEX) +else: + NO_DEL_INDEX = None + +if DEL_INDEX is None and NO_DEL_INDEX is None: + raise RuntimeError('you must specify at least one of -delindex or -nodelindex') + +def boolToYesNo(b): + if b: + return 'true' + else: + return 'false' + +def run(new, query, sortField, doScore, filt, delP, doTotalHits): + global counter + + t0 = time.time() + + s = ALG + + if not VERIFY: + s = s.replace('$ROUNDS$', +''' +{ "Rounds" + { "Run" + { "TestSearchSpeed" + { "XSearchReal" $SEARCH$ > : 3.0s + } + NewRound + } : $NROUND$ +} +''') + nround = 5 + s = s.replace('$NROUND$', str(nround)) + 
else: + s = s.replace('$ROUNDS$', '') + + if linux: + prefix = '/big/scratch/lucene' + else: + prefix = '/lucene' + + if delP is None: + index = NO_DEL_INDEX + else: + index = DEL_INDEX + + s = s.replace('$INDEX$', index) + + if doTotalHits: + v = 'true' + else: + v = 'false' + s = s.replace('$TOTAL_HITS$', v) + + open('queries.txt', 'wb').write(query + '\n') + + if filt == None: + f = '0.0' + else: + f = str(filt) + s = s.replace('$FILT$', f) + + if doScore == 'both': + doTrackScores = True + doMaxScore = True + elif doScore == 'track': + doTrackScores = True + doMaxScore = False + elif doScore == 'no': + doTrackScores = False + doMaxScore = False + + s = s.replace('$SEARCH_SPEC$', boolToYesNo(new)) + + l = [] + if not doMaxScore: + l.append('nomaxscore') + if not doTrackScores: + l.append('noscore') + + if len(l) > 0: + sv = ',%s' % (','.join(l)) + else: + sv = '' + + if sortField == 'score': + search = 'Search' + elif sortField == 'doctitle': + search = 'SearchWithSort(doctitle:string%s)' % sv + elif sortField == 'docdate': + search = 'SearchWithSort(docdate:long%s)' % sv + else: + raise RuntimeError("no") + + s = s.replace('$SEARCH$', search) + fileOut = 'searchlogs/%d' % counter + counter += 1 + + if 0: + fileOut = 'searchlogs/query_%s_%s_%s_%s_%s_%s' % \ + (query.replace(' ', '_').replace('"', ''), + sortField, filt, doScore, new, delP) + + open('tmp.alg', 'wb').write(s) + + if windows: + command = 'java -Xms1024M -Xmx1024M -Xbatch -server -cp "../../build/classes/java;../../build/classes/demo;../../build/contrib/highlighter/classes/java;../../build/contrib/spec/classes/java;../../contrib/benchmark/lib/commons-digester-1.7.jar;../../contrib/benchmark/lib/commons-collections-3.1.jar;../../contrib/benchmark/lib/commons-logging-1.0.4.jar;../../contrib/benchmark/lib/commons-beanutils-1.7.0.jar;../../contrib/benchmark/lib/xerces-2.9.0.jar;../../contrib/benchmark/lib/xml-apis-2.9.0.jar;../../build/contrib/benchmark/classes/java" 
org.apache.lucene.benchmark.byTask.Benchmark tmp.alg > %s' % fileOut + else: + command = 'java -Xms1024M -Xmx1024M -Xbatch -server -cp ../../build/classes/java:../../build/classes/demo:../../build/contrib/highlighter/classes/java:../../build/contrib/spec/classes/java:../../contrib/benchmark/lib/commons-digester-1.7.jar:../../contrib/benchmark/lib/commons-collections-3.1.jar:../../contrib/benchmark/lib/commons-logging-1.0.4.jar:../../contrib/benchmark/lib/commons-beanutils-1.7.0.jar:../../contrib/benchmark/lib/xerces-2.9.0.jar:../../contrib/benchmark/lib/xml-apis-2.9.0.jar:../../build/contrib/benchmark/classes/java org.apache.lucene.benchmark.byTask.Benchmark tmp.alg > %s' % fileOut + + print ' %s' % fileOut + + res = os.system(command) + + if res != 0: + raise RuntimeError('FAILED') + + best = None + count = 0 + nhits = None + warmTime = None + meths = [] + r = re.compile('^ ([0-9]+): (.*)$') + topN = [] + sawSpec = False + maxScore = None + spec = None + + for line in open(fileOut, 'rb').readlines(): + m = r.match(line.rstrip()) + if m is not None: + topN.append(m.group(2)) + if line.find('SPEC search') != -1: + sawSpec = True + spec = line[12:].strip() + if line.startswith('NUMHITS='): + nhits = int(line[8:].strip()) + if line.startswith('MAXSCORE='): + maxScore = line[9:].strip() + if line.startswith('XSearchWarm'): + v = line.strip().split() + warmTime = float(v[5]) + if line.startswith('XSearchReal'): + v = line.strip().split() + # print len(v), v + upto = 0 + i = 0 + qps = None + while i < len(v): + if v[i] == '-': + i += 1 + continue + else: + upto += 1 + i += 1 + if upto == 5: + #print 'GOT: %s' % v[i-1] + qps = float(v[i-1].replace(',', '')) + break + + if qps is None: + raise RuntimeError('did not find qps') + + count += 1 + if best is None or qps > best: + best = qps + + if sawSpec != new: + raise RuntimeError('spec did not kick in properly') + + if nhits is None: + raise RuntimeError('did not see NUMHITS= line') + + if maxScore is None: + raise 
RuntimeError('did not see MAXSCORE= line') + + if not VERIFY: + if count != nround: + raise RuntimeError('did not find %s rounds (got %s)' % (nround, count)) + + if warmTime is None: + raise RuntimeError('did not find warm time') + + # print ' NHIT: %s' % nhits + + # print ' %.1f qps; %.1f sec' % (best, time.time()-t0) + all.append((new, query, sortBy, filt, nhits, warmTime, best)) + else: + best = 1.0 + + return nhits, best, topN, maxScore, spec + +def cleanScores(l): + for i in range(len(l)): + pos = l[i].find(' score=') + l[i] = l[i][:pos].strip() + +rx = re.compile('doc=(\d+) score=(.*?)$') +rx2 = re.compile('doc=(\d+) v=(.*?) score=(.*?)$') + +def parse(topN, hasValues): + l = [] + for e in topN: + if hasValues: + m = rx2.search(e) + l.append((int(m.group(1)), m.group(2), m.group(3))) + else: + m = rx.search(e) + l.append((int(m.group(1)), m.group(2))) + return l + +def equals(topN1, topN2, hasValues): + if len(topN1) != len(topN2): + return False + top1 = parse(topN1, hasValues) + top2 = parse(topN2, hasValues) + for i in range(len(top1)): + if not hasValues: + doc1, score1 = top1[i] + doc2, score2 = top2[i] + else: + doc1, v1, score1 = top1[i] + doc2, v2, score2 = top2[i] + if doc1 != doc2: + return False + if score1 == 'Nan' and score2 == 'Nan': + pass + elif score1 == 'Nan' or score2 == 'Nan': + return False + else: + score1 = float(score1) + score2 = float(score2) + if abs(score1-score2) > 0.000001: + return False + if hasValues and v1 != v2: + return False + + return True + +all = [] + +if VERIFY: + filts = (10.0, None, 25.0) +else: + filts = (None, 25.0) + +if VERIFY: + queries = ['1 OR 2 OR 3', '1 OR -2', '1 OR 2', '1'] + #queries = ['1 OR 2 OR 3'] +else: + queries = ['1 OR 2 -3'] + +if VERIFY: + sort = ('score', 'doctitle', 'docdate') +else: + sort = ('doctitle',) + +deletes = (None, 5) + +if VERIFY: + doScores = ('both', 'track', 'no') +else: + doScores = ('no', 'track') + +for query in queries: + for sortBy in sort: + + if sortBy == 'score': + 
doScores0 = ('both',) + else: + doScores0 = doScores + + for doScore in doScores0: + for filt in filts: + for delP in deletes: + + doTotalHits = False + + if delP is None and NO_DEL_INDEX is None: + continue + if delP is not None and DEL_INDEX is None: + continue + + print + print 'RUN: query=%s sort=%s scores=%s filt=%s deletes=%s [%s]' % (query, sortBy, doScore, filt, delP, datetime.datetime.now()) + print ' new...' + nhits1, qps1, topN1, maxScore1, spec = run(True, query, sortBy, doScore, filt, delP, doTotalHits) + print ' qps %.2f' % qps1 + print ' spec = src/java/org/apache/lucene/search/%s.java' % spec[25:] + print ' old...' + nhits2, qps2, topN2, maxScore2, ign = run(False, query, sortBy, doScore, filt, delP, doTotalHits) + print ' qps %.2f' % qps2 + if doTotalHits: + print ' %d hits' % nhits1 + print ' %.1f%%' % (100.*(qps1-qps2)/qps2) + + f = open('%s.pk' % RESULTS, 'wb') + cPickle.dump(all, f) + f.close() + + if nhits1 != nhits2 and doTotalHits: + raise RuntimeError('hits differ: %s vs %s' % (nhits1, nhits2)) + + if maxScore1 is None or maxScore2 is None: + if maxScore1 != maxScore2: + raise RuntimeError('maxScore differ: %s vs %s' % (maxScore1, maxScore2)) + elif abs(float(maxScore1)-float(maxScore2)) > 0.00001: + raise RuntimeError('maxScore differ: %s vs %s' % (maxScore1, maxScore2)) + + if len(topN1) != numHit: + raise RuntimeError('not enough hits: %s vs %s' % (len(topN1), nhits1)) + + if not equals(topN1, topN2, sortBy != 'score'): + raise RuntimeError('results differ') + + if sortBy == 'score': + s0 = 'Relevance' + elif sortBy == 'doctitle': + s0 = 'Title (string)' + elif sortBy == 'docdate': + s0 = 'Date (long)' + + if filt == None: + f = 'no' + else: + f = '%d%%' % filt + + if delP == None: + d = 'no' + else: + d = '%d%%' % delP + + if doScore == 'both': + s = 'Track,Max' + elif doScore == 'track': + s = 'Track' + else: + s = 'no' + + pct = (qps1-qps2)/qps2 + if pct <= 0.0: + color = 'red' + else: + color = 'green' + p = '{color:%s}%.1f%%{color}' 
% (color, 100.*pct) + + open('%s.txt' % RESULTS, 'ab').write('|%s|%s|%s|%s|%s|%d|%.1f|%.1f|%s|\n' % \ + (query, s0, f, d, s, nhits1, qps2, qps1, p)) Index: verify.cmd =================================================================== --- verify.cmd (revision 0) +++ verify.cmd (revision 0) @@ -0,0 +1 @@ +python -u bench.py -delindex /lucene/work.wikifull.8seg.5pdel -nodelindex /lucene/work.wikifull.8seg -verify Index: NOTES.txt =================================================================== --- NOTES.txt (revision 0) +++ NOTES.txt (revision 0) @@ -0,0 +1,44 @@ + +To run this: + + * cd contrib/spec + + * python -u genall.py + + * ant compile + +Then change your code like this: + + import org.apache.lucene.search.FastSearch; + + final private FastSearch fastSearch = new FastSearch(); + + if (sort == null) { + hits = fastSearch.search(searcher, q, filter, numHits, doTotalHits); + } else { + hits = fastSearch.search(searcher, q, filter, sort, withScore(), withMaxScore(), true, numHits, doTotalHits); + } + +That call will simply pass through to IndexSearcher if there's no +specialized class. + +Caveats/limitations: + + * This code is NOT thread safe; it uses pre-initialized static state + + * If you pass a filter in, it must 1) return OpenDocIdSet from + getDocIdSet(), and 2) that bit set must have already "folded in" + deletes + + * When you sort by String, this code will silently replace any null + values with a sentinel value (U+0000) + + * If you sort by field, no docs may have the "sentinel" value (eg if + you sort by long, Long.MAX_VALUE). + + * It can only specialize single-field, or by score topN collection + + * It cannot do reversed sort + + * It can only handle single TermQuery, or N-clause OR of TermQuery + (with MUST_NOT clauses and with minimumNumberShouldMatch). 
Index: src/java/org/apache/lucene/search/SpecSearch.java =================================================================== --- src/java/org/apache/lucene/search/SpecSearch.java (revision 0) +++ src/java/org/apache/lucene/search/SpecSearch.java (revision 0) @@ -0,0 +1,25 @@ +package org.apache.lucene.search; + +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.IOException; + +abstract class SpecSearch { + abstract SpecSearch getInstance(); + abstract TopDocs search(IndexSearcher searcher, Query query, Filter filter, Sort sort, final int topN) throws IOException; +} \ No newline at end of file Index: src/java/org/apache/lucene/search/FastSearch.java =================================================================== --- src/java/org/apache/lucene/search/FastSearch.java (revision 0) +++ src/java/org/apache/lucene/search/FastSearch.java (revision 0) @@ -0,0 +1,307 @@ +package org.apache.lucene.search; + +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.util.OpenBitSet; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.MultiSegmentReader; +import java.io.IOException; +import java.util.List; +import java.util.ArrayList; +import java.util.Iterator; + +/** + * If you pass a non-null filter, you must ensure: + * + * 1) Its getDocIdSet method always returns an OpenBitSet + * + * 2) That returned OpenBitSet has pre-multiplied in + * any doc deletions + */ + +public class FastSearch { + + public TopDocs search(IndexSearcher searcher, + Query query, + Filter filter, + int topN, + boolean doTotalHits) throws IOException { + return _search(searcher, query, filter, null, false, false, false, topN, doTotalHits); + } + + + /** When possible, uses a pre-compiled specialized class + * to execute the specific sort; else, simply falls + * through to the "normal" corresponding method in + * IndexSearcher */ + public TopFieldDocs search(IndexSearcher searcher, + Query query, + Filter filter, + Sort sort, + boolean doTrackScores, + boolean doMaxScore, + boolean fillFields, + int topN, + boolean doTotalHits) throws IOException { + + return (TopFieldDocs) _search(searcher, query, filter, sort, doTrackScores, doMaxScore, fillFields, topN, doTotalHits); + } + + static private boolean first = true; + private TopDocs _search(IndexSearcher searcher, + Query query, + Filter filter, + Sort sort, + boolean 
doTrackScores, + boolean doMaxScore, + boolean fillFields, + int topN, + boolean doTotalHits) throws IOException { + + query = searcher.rewrite(query); + + final String specClassName = getSpecClassName(searcher, query, filter, sort, doTrackScores, doMaxScore, topN, doTotalHits); + + if (specClassName != null) { + // specializer can handle this query: load the class + Class c; + + try { + c = Class.forName(specClassName); + } catch (ClassNotFoundException e) { + throw new IOException("unable to find specialized search class " + specClassName); + } + + SpecSearch specSearch; + try { + specSearch = (SpecSearch) c.newInstance(); + } catch (IllegalAccessException e) { + throw new IOException("IllegalAccessException when instantiating SpecClass " + specClassName); + } catch (InstantiationException e) { + throw new IOException("InstantiationException when instantiating SpecClass " + specClassName); + } catch (ClassCastException e) { + throw new IOException("unable to cast SpecClass " + specClassName + " instance to a SpecSearch"); + } + + specSearch = specSearch.getInstance(); + + if (first) { + System.out.println("SPEC search=" + specClassName); + first = false; + } + + return specSearch.search(searcher, query, filter, sort, topN); + + } else { + // fallback + if (sort != null) { + final TopFieldCollector collector = TopFieldCollector.create(sort, topN, fillFields, doTrackScores, doMaxScore); + searcher.search(query, filter, collector); + return collector.topDocs(); + } else { + return searcher.search(query, filter, topN, sort); + } + } + } + + public String getSpecClassName(IndexSearcher searcher, + Query query, + Filter filter, + Sort sort, + boolean doTrackScores, + boolean doMaxScore, + int topN, + boolean doTotalHits) throws IOException { + + final IndexReader r = searcher.getIndexReader(); + if (r instanceof MultiSegmentReader) { + String sortDesc; + if (sort == null) { + sortDesc = "Score"; + doTrackScores = true; + doMaxScore = true; + } else { + SortField[] 
sortFields = sort.getSort(); + if (sortFields.length == 1) { + if (!sortFields[0].getReverse()) { + switch(sortFields[0].getType()) { + case SortField.SCORE: + sortDesc = "Score"; + break; + case SortField.DOC: + sortDesc = null; + break; + case SortField.AUTO: + sortDesc = null; + break; + case SortField.STRING: + sortDesc = "String"; + break; + case SortField.INT: + sortDesc = "Int"; + break; + case SortField.FLOAT: + sortDesc = "Float"; + break; + case SortField.LONG: + sortDesc = "Long"; + break; + case SortField.DOUBLE: + sortDesc = "Double"; + break; + case SortField.SHORT: + sortDesc = "Short"; + break; + case SortField.CUSTOM: + sortDesc = null; + break; + case SortField.BYTE: + sortDesc = "Byte"; + break; + default: + sortDesc = null; + break; + } + } else { + sortDesc = null; + } + } else { + sortDesc = null; + } + } + + String trackScoresDesc; + if (doTrackScores) { + trackScoresDesc = "Yes"; + } else { + trackScoresDesc = "No"; + } + + String maxScoreDesc; + if (doMaxScore) { + maxScoreDesc = "Yes"; + } else { + maxScoreDesc = "No"; + } + + String queryDesc; + String field = null; + + if (query instanceof TermQuery) { + queryDesc = "Term"; + field = ((TermQuery) query).getTerm().field(); + } else if (query instanceof BooleanQuery) { + BooleanClause[] clauses = ((BooleanQuery) query).getClauses(); + final int clauseCount = clauses.length; + boolean fail = false; + List shouldClauses = new ArrayList(); + List mustNotClauses = new ArrayList(); + for(int i=0;i shouldClauses.size()) { + minNR = shouldClauses.size(); + } + queryDesc = shouldClauses.size() + "Should" + mustNotClauses.size() + "MustNot_MinNr" + minNR; + // should clauses must all be first + BooleanQuery bq = new BooleanQuery(); + Iterator it = shouldClauses.iterator(); + while(it.hasNext()) { + bq.add((BooleanClause) it.next()); + } + it = mustNotClauses.iterator(); + while(it.hasNext()) { + bq.add((BooleanClause) it.next()); + } + query = bq; + } else { + queryDesc = null; + field = null; + } 
+ } else { + field = null; + queryDesc = null; + } + + String filterDesc; + if (filter == null) { + filterDesc = "No"; + } else { + filterDesc = "Yes"; + } + + String delDesc; + if (r.hasDeletions()) { + delDesc = "Yes"; + } else { + delDesc = "No"; + } + + if (filter != null) { + // Filter must have pre-multiplied deletes + delDesc = "No"; + } + + String totalHitsDesc; + if (doTotalHits) { + totalHitsDesc = "Yes"; + } else { + totalHitsDesc = "No"; + } + + String normsDesc; + if (!r.hasNorms(field)) { + normsDesc = "No"; + } else { + normsDesc = "Yes"; + } + + if (sortDesc != null && filterDesc != null && queryDesc != null) { + return "org.apache.lucene.search.Spec_Query" + queryDesc + "_Sort" + sortDesc + "_Filter" + filterDesc + "_Deletes" + delDesc + "_TrackScores" + trackScoresDesc + "_MaxScore" + maxScoreDesc + "_Norms" + normsDesc + "_TotalHits" + totalHitsDesc; + } else { + return null; + } + } else { + return null; + } + } +} \ No newline at end of file Index: genall.py =================================================================== --- genall.py (revision 0) +++ genall.py (revision 0) @@ -0,0 +1,81 @@ +import os +import shutil +import gen + +numHit = 10 + +# queries = ('Term', '2ClauseOr', '3ClauseOr') +queries = [('or', 3, 0, 0)] + +def main(): + + os.system('rm -f src/java/org/apache/lucene/search/Spec_Query*.java') + + for query in queries: + for filterDesc in ('No', 'Yes'): + if filterDesc == 'Yes': + delDescs = ('No',) + else: + delDescs = ('No', 'Yes') + for delDesc in delDescs: + for sortDesc in ('Score', 'String', 'Byte', 'Short', 'Int', 'Long', 'Float', 'Double'): + + if sortDesc != 'Score': + maxScoreDescs = ('Yes', 'No') + trackScoresDescs = ('Yes', 'No') + else: + maxScoreDescs = ('Yes',) + trackScoresDescs = ('Yes',) + + for maxScoreDesc in maxScoreDescs: + for trackScoresDesc in trackScoresDescs: + for norms in ('Yes', 'No'): + for doTotalHits in ('Yes', 'No'): + + if query[0] == 'or': + numShould = query[1] + numMustNot = query[2] + 
minNR = query[3] + queryDesc = '%dShould%dMustNot_MinNr%s' % (numShould, numMustNot, minNR) + else: + queryDesc = query + minNR = None + + fileName = 'src/java/org/apache/lucene/search/Spec_Query%s_Sort%s_Filter%s_Deletes%s_TrackScores%s_MaxScore%s_Norms%s_TotalHits%s.java' % \ + (queryDesc, sortDesc, filterDesc, delDesc, trackScoresDesc, maxScoreDesc, norms, doTotalHits) + + w = gen.Writer() + + if sortDesc == 'Score': + c = gen.ScoreDocCollector(w) + else: + c = gen.SortByOneFieldCollector(w, trackScoresDesc=='Yes', maxScoreDesc=='Yes') + + if sortDesc == 'String': + c.comp = gen.StringOrdValComparator(w, replaceNulls=True) + c.valueType = None + else: + c.comp = None + c.valueType = sortDesc.lower() + + if queryDesc == 'Term': + q = gen.TermQuery() + else: + + qShould = [] + for i in range(numShould): + qShould.append(gen.TermQuery()) + qMustNot = [] + for i in range(numMustNot): + qMustNot.append(gen.TermQuery()) + q = gen.NClauseOrQuery(qShould, qMustNot, minNR=minNR) + + if False and fileName.find('Spec_Query3ClauseOr_SortString_FilterYes_DeletesNo_TrackScoresYes_MaxScoreYes_NormsYes_TotalHitsNo') == -1: + continue + + # print 'Generate %s...' % (fileName[fileName.rfind('/')+1:]) + gen.gen(w, q, c, filterDesc=='Yes', delDesc=='Yes', + norms=='Yes', doTotalHits=='Yes', fileName) + +if __name__ == '__main__': + main() Index: build.xml =================================================================== --- build.xml (revision 0) +++ build.xml (revision 0) @@ -0,0 +1,27 @@ + + + + + + + + Fast source-code specialized search + + + +