Index: lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java =================================================================== --- lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java (revision 1453398) +++ lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java (working copy) @@ -17,16 +17,17 @@ package org.apache.lucene.classification; import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.util.BytesRef; import org.junit.Test; /** * Testcase for {@link KNearestNeighborClassifier} */ -public class KNearestNeighborClassifierTest extends ClassificationTestBase { +public class KNearestNeighborClassifierTest extends ClassificationTestBase { @Test public void testBasicUsage() throws Exception { - checkCorrectClassification(new KNearestNeighborClassifier(1), new MockAnalyzer(random())); + checkCorrectClassification(new KNearestNeighborClassifier(1), new BytesRef("technology"), new MockAnalyzer(random()), categoryFieldName); } } Index: lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java =================================================================== --- lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java (revision 1453398) +++ lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java (working copy) @@ -19,6 +19,7 @@ import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.MockAnalyzer; import org.apache.lucene.analysis.ngram.EdgeNGramTokenizer; +import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.LuceneTestCase; import org.junit.Test; @@ -29,16 +30,16 @@ */ // TODO : eventually remove this if / when fallback methods exist for all un-supportable codec methods (see LUCENE-4872) @LuceneTestCase.SuppressCodecs("Lucene3x") -public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase { +public class SimpleNaiveBayesClassifierTest extends ClassificationTestBase { @Test public void testBasicUsage() throws Exception { - checkCorrectClassification(new SimpleNaiveBayesClassifier(), new MockAnalyzer(random())); + checkCorrectClassification(new SimpleNaiveBayesClassifier(), new BytesRef("technology"), new MockAnalyzer(random()), categoryFieldName); } @Test public void testNGramUsage() throws Exception { - checkCorrectClassification(new SimpleNaiveBayesClassifier(), new NGramAnalyzer()); + checkCorrectClassification(new SimpleNaiveBayesClassifier(), new BytesRef("technology"), new NGramAnalyzer(), categoryFieldName); } private class NGramAnalyzer extends Analyzer { Index: lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java =================================================================== --- lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java (revision 0) +++ lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java (revision 0) @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.classification; + +import org.apache.lucene.analysis.MockAnalyzer; +import org.junit.Test; + +/** + * Testcase for {@link BooleanPerceptronClassifier} + */ +public class BooleanPerceptronClassifierTest extends ClassificationTestBase { + + @Test + public void testBasicUsage() throws Exception { + checkCorrectClassification(new BooleanPerceptronClassifier(10d), Boolean.TRUE, new MockAnalyzer(random()), booleanFieldName); + } + +} Property changes on: lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java ___________________________________________________________________ Added: svn:eol-style + native Index: lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java =================================================================== --- lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java (revision 1453398) +++ lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java (working copy) @@ -24,7 +24,6 @@ import org.apache.lucene.index.RandomIndexWriter; import org.apache.lucene.index.SlowCompositeReaderWrapper; import org.apache.lucene.store.Directory; -import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.LuceneTestCase; import org.junit.After; import org.junit.Before; @@ -32,12 +31,13 @@ /** * Base class for testing {@link Classifier}s */ -public abstract class ClassificationTestBase extends LuceneTestCase { +public abstract class ClassificationTestBase extends LuceneTestCase { private RandomIndexWriter indexWriter; private String textFieldName; - private String classFieldName; private Directory dir; + String categoryFieldName; + String booleanFieldName; @Override @Before @@ -46,7 +46,8 @@ dir = newDirectory(); indexWriter = new RandomIndexWriter(random(), dir); textFieldName = "text"; - classFieldName = "cat"; + categoryFieldName = "cat"; + booleanFieldName = "bool"; } @Override @@ -58,17 +59,17 @@ } - protected void checkCorrectClassification(Classifier classifier, Analyzer analyzer) throws Exception { + protected void checkCorrectClassification(Classifier classifier, T expectedResult, Analyzer analyzer, String classFieldName) throws Exception { SlowCompositeReaderWrapper compositeReaderWrapper = null; try { populateIndex(analyzer); compositeReaderWrapper = new SlowCompositeReaderWrapper(indexWriter.getReader()); classifier.train(compositeReaderWrapper, textFieldName, classFieldName, analyzer); String newText = "Much is made of what the likes of Facebook, Google and Apple know about users. Truth is, Amazon may know more."; - ClassificationResult classificationResult = classifier.assignClass(newText); + ClassificationResult classificationResult = classifier.assignClass(newText); assertNotNull(classificationResult.getAssignedClass()); - assertEquals(new BytesRef("technology"), classificationResult.getAssignedClass()); - assertTrue(classificationResult.getScore() > 0); + assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass()); + assertTrue("got a not positive score " + classificationResult.getScore(), classificationResult.getScore() > 0); } finally { if (compositeReaderWrapper != null) compositeReaderWrapper.close(); @@ -86,48 +87,55 @@ doc.add(new Field(textFieldName, "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " + "who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " + "the Unknown Soldier in Warsaw Tuesday.", ft)); - doc.add(new Field(classFieldName, "politics", ft)); + doc.add(new Field(categoryFieldName, "politics", ft)); + doc.add(new Field(booleanFieldName, "false", ft)); indexWriter.addDocument(doc, analyzer); doc = new Document(); doc.add(new Field(textFieldName, "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" + " States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.", ft)); - doc.add(new Field(classFieldName, "politics", ft)); + doc.add(new Field(categoryFieldName, "politics", ft)); + doc.add(new Field(booleanFieldName, "false", ft)); indexWriter.addDocument(doc, analyzer); doc = new Document(); doc.add(new Field(textFieldName, "And there's a threshold question that he has to answer for the American people and " + "that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " + "know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"", ft)); - doc.add(new Field(classFieldName, "politics", ft)); + doc.add(new Field(categoryFieldName, "politics", ft)); + doc.add(new Field(booleanFieldName, "false", ft)); indexWriter.addDocument(doc, analyzer); doc = new Document(); doc.add(new Field(textFieldName, "Still, when it comes to gun policy, many congressional Democrats have \"decided to " + "keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " + "Albany's School of Criminal Justice.", ft)); - doc.add(new Field(classFieldName, "politics", ft)); + doc.add(new Field(categoryFieldName, "politics", ft)); + doc.add(new Field(booleanFieldName, "false", ft)); indexWriter.addDocument(doc, analyzer); doc = new Document(); doc.add(new Field(textFieldName, "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " + "technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " + "world through the Internet.", ft)); - doc.add(new Field(classFieldName, "technology", ft)); + doc.add(new Field(categoryFieldName, "technology", ft)); + doc.add(new Field(booleanFieldName, "true", ft)); indexWriter.addDocument(doc, analyzer); doc = new Document(); doc.add(new Field(textFieldName, "So, about all those experts and analysts who've spent the past year or so saying " + "Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.", ft)); - doc.add(new Field(classFieldName, "technology", ft)); + doc.add(new Field(categoryFieldName, "technology", ft)); + doc.add(new Field(booleanFieldName, "true", ft)); indexWriter.addDocument(doc, analyzer); doc = new Document(); doc.add(new Field(textFieldName, "More than 400 million people trust Google with their e-mail, and 50 million store files" + " in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " + "generally transfer or store huge volumes of personal data online.", ft)); - doc.add(new Field(classFieldName, "technology", ft)); + doc.add(new Field(categoryFieldName, "technology", ft)); + doc.add(new Field(booleanFieldName, "true", ft)); indexWriter.addDocument(doc, analyzer); indexWriter.commit(); Index: lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java =================================================================== --- lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java (revision 1453398) +++ lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java (working copy) @@ -58,6 +58,9 @@ */ @Override public ClassificationResult assignClass(String text) throws IOException { + if (mlt == null) { + throw new IOException("You must first call Classifier#train first"); + } Query q = mlt.like(new StringReader(text), textFieldName); TopDocs topDocs = indexSearcher.search(q, k); return selectClassFromNeighbors(topDocs); Index: lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java =================================================================== --- lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java (revision 1453398) +++ lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java (working copy) @@ -103,7 +103,7 @@ @Override public ClassificationResult assignClass(String inputDocument) throws IOException { if (atomicReader == null) { - throw new RuntimeException("need to train the classifier first"); + throw new IOException("You must first call Classifier#train first"); } double max = 0d; BytesRef foundClass = new BytesRef(); Index: lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java =================================================================== --- lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java (revision 0) +++ lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java (revision 0) @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.classification; + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.lucene.index.AtomicReader; +import org.apache.lucene.index.MultiFields; +import org.apache.lucene.index.StorableField; +import org.apache.lucene.index.StoredDocument; +import org.apache.lucene.index.Terms; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.IntsRef; +import org.apache.lucene.util.fst.Builder; +import org.apache.lucene.util.fst.FST; +import org.apache.lucene.util.fst.PositiveIntOutputs; +import org.apache.lucene.util.fst.Util; + +import java.io.IOException; +import java.io.StringReader; +import java.util.Map; +import java.util.SortedMap; +import java.util.TreeMap; + +/** + * A perceptron (see http://en.wikipedia.org/wiki/Perceptron) based + * Boolean {@link org.apache.lucene.classification.Classifier}. + * The weights are calculated using {@link org.apache.lucene.index.TermsEnum#totalTermFreq} + * both on a per field and a per document basis and then a corresponding {@link FST} is used for class assignment. + */ +public class BooleanPerceptronClassifier implements Classifier { + + private final Double threshold; + private Terms textTerms; + private Analyzer analyzer; + private String textFieldName; + private FST fst; + + /** + * Create a {@link BooleanPerceptronClassifier} + * + * @param threshold the binary threshold for perceptron output evaluation + */ + public BooleanPerceptronClassifier(Double threshold) { + this.threshold = threshold; + } + + /** + * {@inheritDoc} + */ + @Override + public ClassificationResult assignClass(String text) throws IOException { + if (textTerms == null) { + throw new IOException("You must first call Classifier#train first"); + } + Long output = 0l; + // TODO : make this a FST traversal + TokenStream tokenStream = analyzer.tokenStream(textFieldName, new StringReader(text)); + CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class); + tokenStream.reset(); + while (tokenStream.incrementToken()) { + String s = charTermAttribute.toString(); + Long d = Util.get(fst, new BytesRef(s)); + if (d != null && d > 0) { + output += d; + } + } + tokenStream.end(); + tokenStream.close(); + + return new ClassificationResult(output >= threshold, output.doubleValue()); + } + + /** + * {@inheritDoc} + */ + @Override + public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException { + this.textTerms = MultiFields.getTerms(atomicReader, textFieldName); + this.analyzer = analyzer; + this.textFieldName = textFieldName; + + SortedMap weights = new TreeMap(); // this needs to be sorted to make FST update work + TermsEnum reuse = textTerms.iterator(null); + BytesRef textTerm; + while ((textTerm = reuse.next()) != null) { + weights.put(textTerm.utf8ToString(), (double) reuse.totalTermFreq()); + } + updateFST(weights); + + IndexSearcher indexSearcher = new IndexSearcher(atomicReader); + // for each doc + for (ScoreDoc scoreDoc : indexSearcher.search(new MatchAllDocsQuery(), Integer.MAX_VALUE).scoreDocs) { + StoredDocument doc = indexSearcher.doc(scoreDoc.doc); + + // assign class to the doc + ClassificationResult classificationResult = assignClass(doc.getField(textFieldName).stringValue()); + Boolean assignedClass = classificationResult.getAssignedClass(); + + // get the expected result + StorableField field = doc.getField(classFieldName); + + Boolean correctClass = Boolean.valueOf(field.stringValue()); + double modifier = correctClass.compareTo(assignedClass); + if (modifier != 0) { + TermsEnum cte = textTerms.iterator(reuse); + + // get the doc term vectors + Terms terms = atomicReader.getTermVector(scoreDoc.doc, textFieldName); + + TermsEnum termsEnum = terms.iterator(null); + + BytesRef term; + while ((term = termsEnum.next()) != null) { + cte.seekExact(term, true); + if (assignedClass != null) { + String termString = cte.term().utf8ToString(); + long termFreqLocal = termsEnum.totalTermFreq(); + // update weights + weights.put(termString, weights.get(termString) + modifier * termFreqLocal); + } + } + updateFST(weights); + reuse = cte; + } + } + weights.clear(); // free memory while waiting for GC + } + + private void updateFST(SortedMap weights) throws IOException { + PositiveIntOutputs outputs = PositiveIntOutputs.getSingleton(true); + Builder builder = new Builder(FST.INPUT_TYPE.BYTE1, outputs); + BytesRef scratchBytes = new BytesRef(); + IntsRef scratchInts = new IntsRef(); + for (Map.Entry entry : weights.entrySet()) { + scratchBytes.copyChars(entry.getKey()); + builder.add(Util.toIntsRef(scratchBytes, scratchInts), entry.getValue().longValue()); + } + fst = builder.finish(); + } + +} Property changes on: lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java ___________________________________________________________________ Added: svn:eol-style + native