Index: lucene/classification/ivy.xml =================================================================== --- lucene/classification/ivy.xml (revision 0) +++ lucene/classification/ivy.xml (revision 0) @@ -0,0 +1,21 @@ + + + + Index: lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java =================================================================== --- lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java (revision 0) +++ lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java (revision 0) @@ -0,0 +1,130 @@ +package org.apache.lucene.classification; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.MockAnalyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.TextField; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.index.SlowCompositeReaderWrapper; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.LuceneTestCase; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Testcase for {@link SimpleNaiveBayesClassifier} + */ +public class SimpleNaiveBayesClassifierTest extends LuceneTestCase { + + private RandomIndexWriter indexWriter; + private String textFieldName; + private String classFieldName; + private Analyzer analyzer; + private Directory dir; + + @Before + public void setUp() throws Exception { + super.setUp(); + analyzer = new MockAnalyzer(random()); + dir = newDirectory(); + indexWriter = new RandomIndexWriter(random(), dir); + textFieldName = "text"; + classFieldName = "cat"; + } + + @After + public void tearDown() throws Exception { + super.tearDown(); + indexWriter.close(); + dir.close(); + } + + @Test + public void testBasicUsage() throws Exception { + SlowCompositeReaderWrapper compositeReaderWrapper = null; + try { + populateIndex(); + SimpleNaiveBayesClassifier simpleNaiveBayesClassifier = new SimpleNaiveBayesClassifier(); + compositeReaderWrapper = new SlowCompositeReaderWrapper(indexWriter.getReader()); + simpleNaiveBayesClassifier.train(compositeReaderWrapper, textFieldName, classFieldName, analyzer); + String newText = "Much is made of what the likes of Facebook, Google and Apple know about users. Truth is, Amazon may know more. "; + assertEquals("technology", simpleNaiveBayesClassifier.assignClass(newText)); + } finally { + if (compositeReaderWrapper != null) + compositeReaderWrapper.close(); + } + } + + private void populateIndex() throws Exception { + + Document doc = new Document(); + doc.add(new TextField(textFieldName, "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " + + "who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " + + "the Unknown Soldier in Warsaw Tuesday.", Field.Store.YES)); + doc.add(new TextField(classFieldName, "politics", Field.Store.YES)); + + indexWriter.addDocument(doc, analyzer); + + doc = new Document(); + doc.add(new TextField(textFieldName, "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" + + " States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.", Field.Store.YES)); + doc.add(new TextField(classFieldName, "politics", Field.Store.YES)); + indexWriter.addDocument(doc, analyzer); + + doc = new Document(); + doc.add(new TextField(textFieldName, "And there's a threshold question that he has to answer for the American people and " + + "that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " + + "know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"", Field.Store.YES)); + doc.add(new TextField(classFieldName, "politics", Field.Store.YES)); + indexWriter.addDocument(doc, analyzer); + + doc = new Document(); + doc.add(new TextField(textFieldName, "Still, when it comes to gun policy, many congressional Democrats have \"decided to " + + "keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " + + "Albany's School of Criminal Justice.", Field.Store.YES)); + doc.add(new TextField(classFieldName, "politics", Field.Store.YES)); + indexWriter.addDocument(doc, analyzer); + + doc = new Document(); + doc.add(new TextField(textFieldName, "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " + + "technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " + + "world through the Internet.", Field.Store.YES)); + doc.add(new TextField(classFieldName, "technology", Field.Store.YES)); + indexWriter.addDocument(doc, analyzer); + + doc = new Document(); + doc.add(new TextField(textFieldName, "So, about all those experts and analysts who've spent the past year or so saying " + + "Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.", Field.Store.YES)); + doc.add(new TextField(classFieldName, "technology", Field.Store.YES)); + indexWriter.addDocument(doc, analyzer); + + doc = new Document(); + doc.add(new TextField(textFieldName, "More than 400 million people trust Google with their e-mail, and 50 million store files" + + " in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " + + "generally transfer or store huge volumes of personal data online.", Field.Store.YES)); + doc.add(new TextField(classFieldName, "technology", Field.Store.YES)); + indexWriter.addDocument(doc, analyzer); + + indexWriter.commit(); + } + +} Index: lucene/classification/src/java/org/apache/lucene/classification/Classifier.java =================================================================== --- lucene/classification/src/java/org/apache/lucene/classification/Classifier.java (revision 0) +++ lucene/classification/src/java/org/apache/lucene/classification/Classifier.java (revision 0) @@ -0,0 +1,47 @@ +package org.apache.lucene.classification; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.index.AtomicReader; + +/** + * A classifier + */ +public interface Classifier { + + /** + * Assigns a class to the given text String + * @param text a String containing text to be classified + * @return a String representing a class + * @throws ClassificationException + */ + public String assignClass(String text) throws ClassificationException; + + /** + * + * @param atomicReader + * @param textFieldName + * @param classFieldName + * @param analyzer + * @throws ClassificationException + */ + public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) + throws ClassificationException; + +} Index: lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java =================================================================== --- lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java (revision 0) +++ lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java (revision 0) @@ -0,0 +1,152 @@ +package org.apache.lucene.classification; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.lucene.index.AtomicReader; +import org.apache.lucene.index.MultiFields; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.Terms; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.util.BytesRef; + +import java.io.IOException; +import java.io.StringReader; +import java.util.Collection; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.Map; + +/** + * A simplistic Lucene based NaiveBayes classifier + */ +public class SimpleNaiveBayesClassifier implements Classifier { + + private AtomicReader atomicReader; + private String textFieldName; + private String classFieldName; + private int docsWithClassSize; + private Analyzer analyzer; + private IndexSearcher indexSearcher; + + public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) + throws ClassificationException { + this.atomicReader = atomicReader; + this.indexSearcher = new IndexSearcher(this.atomicReader); + this.textFieldName = textFieldName; + this.classFieldName = classFieldName; + this.analyzer = analyzer; + try { + docsWithClassSize = MultiFields.getTerms(this.atomicReader, this.classFieldName).getDocCount(); + } catch (IOException e) { + throw new ClassificationException(e); + } + } + + private String[] tokenizeDoc(String doc) throws IOException { + Collection result = new LinkedList(); + TokenStream tokenStream = analyzer.tokenStream(textFieldName, new StringReader(doc)); + CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class); + tokenStream.reset(); + while (tokenStream.incrementToken()) { + result.add(charTermAttribute.toString()); + } + tokenStream.end(); + tokenStream.close(); + return result.toArray(new String[result.size()]); + } + + public String assignClass(String inputDocument) throws ClassificationException { + if (atomicReader == null) { + throw new RuntimeException("need to train the classifier first"); + } + Double max = 0d; + String foundClass = null; + + try { + Terms terms = MultiFields.getTerms(atomicReader, classFieldName); + TermsEnum termsEnum = terms.iterator(null); + BytesRef t = termsEnum.next(); + while (t != null) { + String classValue = t.utf8ToString(); + // TODO : turn it to be in log scale + Double clVal = calculatePrior(classValue) * calculateLikelihood(inputDocument, classValue); + if (clVal > max) { + max = clVal; + foundClass = classValue; + } + t = termsEnum.next(); + } + } catch (Exception e) { + throw new ClassificationException(e); + } + return foundClass; + } + + + private Double calculateLikelihood(String document, String c) throws IOException { + // for each word + Double result = 1d; + for (String word : tokenizeDoc(document)) { + // search with text:word AND class:c + int hits = getWordFreqForClass(word, c); + + // num : count the no of times the word appears in documents of class c (+1) + double num = hits + 1; // +1 is added because of add 1 smoothing + + // den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|) + double den = getTextTermFreqForClass(c) + docsWithClassSize; + + // P(w|c) = num/den + double wordProbability = num / den; + result *= wordProbability; + } + + // P(d|c) = P(w1|c)*...*P(wn|c) + return result; + } + + private double getTextTermFreqForClass(String c) throws IOException { + Terms terms = MultiFields.getTerms(atomicReader, textFieldName); + long numPostings = terms.getSumDocFreq(); // number of term/doc pairs + double avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc + int docsWithC = atomicReader.docFreq(classFieldName, new BytesRef(c)); + return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text field per doc * # docs with c + } + + private int getWordFreqForClass(String word, String c) throws IOException { + BooleanQuery booleanQuery = new BooleanQuery(); + booleanQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.MUST)); + booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST)); + return indexSearcher.search(booleanQuery, 1).totalHits; + } + + private Double calculatePrior(String currentClass) throws IOException { + return (double) docCount(currentClass) / docsWithClassSize; + } + + private int docCount(String countedClass) throws IOException { + return atomicReader.docFreq(new Term(classFieldName, countedClass)); + } +} Index: lucene/classification/src/java/org/apache/lucene/classification/ClassificationException.java =================================================================== --- lucene/classification/src/java/org/apache/lucene/classification/ClassificationException.java (revision 0) +++ lucene/classification/src/java/org/apache/lucene/classification/ClassificationException.java (revision 0) @@ -0,0 +1,36 @@ +package org.apache.lucene.classification; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * An {@link Exception} thrown if any errors occurs within a {@link Classifier} + */ +public class ClassificationException extends Exception { + + public ClassificationException(String s) { + super(s); + } + + public ClassificationException(String s, Throwable throwable) { + super(s, throwable); + } + + public ClassificationException(Throwable throwable) { + super(throwable); + } +} Index: lucene/classification/build.xml =================================================================== --- lucene/classification/build.xml (revision 0) +++ lucene/classification/build.xml (revision 0) @@ -0,0 +1,26 @@ + + + + + + + Classification module for Lucene + + + +