Index: lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java =================================================================== --- lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java (revision 1427690) +++ lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java (working copy) @@ -24,6 +24,7 @@ import org.apache.lucene.index.RandomIndexWriter; import org.apache.lucene.index.SlowCompositeReaderWrapper; import org.apache.lucene.store.Directory; +import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.LuceneTestCase; import org.junit.After; import org.junit.Before; @@ -56,15 +57,17 @@ dir.close(); } - protected void checkCorrectClassification(Classifier classifier, Analyzer analyzer) throws Exception { + + protected void checkCorrectClassification(Classifier classifier, Analyzer analyzer) throws Exception { SlowCompositeReaderWrapper compositeReaderWrapper = null; try { populateIndex(analyzer); compositeReaderWrapper = new SlowCompositeReaderWrapper(indexWriter.getReader()); classifier.train(compositeReaderWrapper, textFieldName, classFieldName, analyzer); String newText = "Much is made of what the likes of Facebook, Google and Apple know about users. Truth is, Amazon may know more."; - ClassificationResult classificationResult = classifier.assignClass(newText); - assertEquals("technology", classificationResult.getAssignedClass()); + ClassificationResult classificationResult = classifier.assignClass(newText); + assertNotNull(classificationResult.getAssignedClass()); + assertEquals(new BytesRef("technology"), classificationResult.getAssignedClass()); assertTrue(classificationResult.getScore() > 0); } finally { if (compositeReaderWrapper != null) Index: lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java =================================================================== --- lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java (revision 1427690) +++ lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java (working copy) @@ -23,6 +23,7 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.util.BytesRef; import java.io.IOException; import java.io.StringReader; @@ -35,7 +36,7 @@ * * @lucene.experimental */ -public class KNearestNeighborClassifier implements Classifier { +public class KNearestNeighborClassifier implements Classifier { private MoreLikeThis mlt; private String textFieldName; @@ -56,27 +57,29 @@ * {@inheritDoc} */ @Override - public ClassificationResult assignClass(String text) throws IOException { + public ClassificationResult assignClass(String text) throws IOException { Query q = mlt.like(new StringReader(text), textFieldName); TopDocs topDocs = indexSearcher.search(q, k); return selectClassFromNeighbors(topDocs); } - private ClassificationResult selectClassFromNeighbors(TopDocs topDocs) throws IOException { + private ClassificationResult selectClassFromNeighbors(TopDocs topDocs) throws IOException { // TODO : improve the nearest neighbor selection - Map classCounts = new HashMap(); + Map classCounts = new HashMap(); for (ScoreDoc scoreDoc : topDocs.scoreDocs) { - String cl = indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue(); - Integer count = classCounts.get(cl); - if (count != null) { - classCounts.put(cl, count + 1); - } else { - classCounts.put(cl, 1); + BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue()); + if (cl != null) { + Integer count = classCounts.get(cl); + if (count != null) { + classCounts.put(cl, count + 1); + } else { + classCounts.put(cl, 1); + } } } double max = 0; - String assignedClass = null; - for (String cl : classCounts.keySet()) { + BytesRef assignedClass = new BytesRef(); + for (BytesRef cl : classCounts.keySet()) { Integer count = classCounts.get(cl); if (count > max) { max = count; @@ -84,7 +87,7 @@ } } double score = max / (double) k; - return new ClassificationResult(assignedClass, score); + return new ClassificationResult(assignedClass, score); } /** Index: lucene/classification/src/java/org/apache/lucene/classification/Classifier.java =================================================================== --- lucene/classification/src/java/org/apache/lucene/classification/Classifier.java (revision 1427690) +++ lucene/classification/src/java/org/apache/lucene/classification/Classifier.java (working copy) @@ -22,18 +22,19 @@ import java.io.IOException; /** - * A classifier, see http://en.wikipedia.org/wiki/Classifier_(mathematics) + * A classifier, see http://en.wikipedia.org/wiki/Classifier_(mathematics), which assign classes of type + * T * @lucene.experimental */ -public interface Classifier { +public interface Classifier { /** * Assign a class (with score) to the given text String * @param text a String containing text to be classified - * @return a {@link ClassificationResult} holding assigned class and score + * @return a {@link ClassificationResult} holding assigned class of type T and score * @throws IOException If there is a low-level I/O error. */ - public ClassificationResult assignClass(String text) throws IOException; + public ClassificationResult assignClass(String text) throws IOException; /** * Train the classifier using the underlying Lucene index Index: lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java =================================================================== --- lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java (revision 1427690) +++ lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java (working copy) @@ -41,7 +41,7 @@ * * @lucene.experimental */ -public class SimpleNaiveBayesClassifier implements Classifier { +public class SimpleNaiveBayesClassifier implements Classifier { private AtomicReader atomicReader; private String textFieldName; @@ -89,12 +89,12 @@ * {@inheritDoc} */ @Override - public ClassificationResult assignClass(String inputDocument) throws IOException { + public ClassificationResult assignClass(String inputDocument) throws IOException { if (atomicReader == null) { throw new RuntimeException("need to train the classifier first"); } double max = 0d; - String foundClass = null; + BytesRef foundClass = new BytesRef(); Terms terms = MultiFields.getTerms(atomicReader, classFieldName); TermsEnum termsEnum = terms.iterator(null); @@ -105,10 +105,10 @@ double clVal = calculatePrior(next) * calculateLikelihood(tokenizedDoc, next); if (clVal > max) { max = clVal; - foundClass = next.utf8ToString(); + foundClass = next; } } - return new ClassificationResult(foundClass, max); + return new ClassificationResult(foundClass, max); } Index: lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java =================================================================== --- lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java (revision 1427690) +++ lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java (working copy) @@ -17,29 +17,29 @@ package org.apache.lucene.classification; /** - * The result of a call to {@link Classifier#assignClass(String)} holding an assigned class and a score. + * The result of a call to {@link Classifier#assignClass(String)} holding an assigned class of type T and a score. * @lucene.experimental */ -public class ClassificationResult { +public class ClassificationResult { - private String assignedClass; + private T assignedClass; private double score; /** * Constructor - * @param assignedClass the class String assigned by a {@link Classifier} + * @param assignedClass the class T assigned by a {@link Classifier} * @param score the score for the assignedClass as a double */ - public ClassificationResult(String assignedClass, double score) { + public ClassificationResult(T assignedClass, double score) { this.assignedClass = assignedClass; this.score = score; } /** * retrieve the result class - * @return a String representing an assigned class + * @return a T representing an assigned class */ - public String getAssignedClass() { + public T getAssignedClass() { return assignedClass; }