use of org.apache.lucene.classification.ClassificationResult in project lucene-solr by apache.
the class SimpleNaiveBayesDocumentClassifier method assignNormClasses.
/**
 * Scores every class term indexed under {@code classFieldName} against the given document and
 * returns the scores after passing them through {@code normClassificationResults}.
 *
 * <p>The score of a class is accumulated over every name in {@code textFieldNames} and, within a
 * field, over every token array extracted from the input document for that field. Each token
 * array contributes the class log prior plus the log likelihood multiplied by the field boost;
 * note that the boost applies to the likelihood only, and that the prior is added once per
 * token array.
 *
 * @param inputDocument the document to classify
 * @return the normalized classification results, or an empty list when the class field has no
 *     terms in the index
 * @throws IOException propagated from index access or from the scoring helpers
 */
private List<ClassificationResult<BytesRef>> assignNormClasses(Document inputDocument) throws IOException {
  List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();

  Terms classes = MultiFields.getTerms(indexReader, classFieldName);
  if (classes == null) {
    // getTerms yields null when the field has no terms; without this guard the iterator()
    // call below would fail with a NullPointerException. No known classes means no results.
    return assignedClasses;
  }
  TermsEnum classesEnum = classes.iterator();

  // Filled in by analyzeSeedDocument: tokens and boost per text field of the input document.
  Map<String, List<String[]>> fieldName2tokensArray = new LinkedHashMap<>();
  Map<String, Float> fieldName2boost = new LinkedHashMap<>();
  analyzeSeedDocument(inputDocument, fieldName2tokensArray, fieldName2boost);

  int docsWithClassSize = countDocsWithClass();

  BytesRef c;
  while ((c = classesEnum.next()) != null) {
    double classScore = 0;
    Term term = new Term(this.classFieldName, c);
    for (String fieldName : textFieldNames) {
      List<String[]> tokensArrays = fieldName2tokensArray.get(fieldName);
      double fieldScore = 0;
      for (String[] fieldTokensArray : tokensArrays) {
        fieldScore += calculateLogPrior(term, docsWithClassSize)
            + calculateLogLikelihood(fieldTokensArray, fieldName, term, docsWithClassSize) * fieldName2boost.get(fieldName);
      }
      classScore += fieldScore;
    }
    assignedClasses.add(new ClassificationResult<>(term.bytes(), classScore));
  }
  return normClassificationResults(assignedClasses);
}
use of org.apache.lucene.classification.ClassificationResult in project lucene-solr by apache.
the class KNearestNeighborDocumentClassifier method getClasses.
/**
 * {@inheritDoc}
 *
 * <p>This implementation runs {@code knnSearch} on the document, converts the hits with
 * {@code buildListFromTopDocs}, sorts the results by their natural ordering and returns a
 * {@code subList} view holding at most {@code max} leading entries.
 */
@Override
public List<ClassificationResult<BytesRef>> getClasses(Document document, int max) throws IOException {
  List<ClassificationResult<BytesRef>> rankedClasses = buildListFromTopDocs(knnSearch(document));
  Collections.sort(rankedClasses);
  // Never ask for more entries than were actually produced.
  int limit = Math.min(max, rankedClasses.size());
  return rankedClasses.subList(0, limit);
}
use of org.apache.lucene.classification.ClassificationResult in project lucene-solr by apache.
the class KNearestNeighborDocumentClassifier method getClasses.
/**
 * {@inheritDoc}
 *
 * <p>This implementation runs {@code knnSearch} on the document, converts the hits with
 * {@code buildListFromTopDocs} and returns all of them sorted by their natural ordering.
 */
@Override
public List<ClassificationResult<BytesRef>> getClasses(Document document) throws IOException {
  List<ClassificationResult<BytesRef>> rankedClasses = buildListFromTopDocs(knnSearch(document));
  Collections.sort(rankedClasses);
  return rankedClasses;
}
use of org.apache.lucene.classification.ClassificationResult in project lucene-solr by apache.
the class ConfusionMatrixGeneratorTest method testGetConfusionMatrix.
/**
 * Builds a confusion matrix from the sample index using a stub classifier and checks that the
 * matrix is populated, that 7 documents were evaluated, and that the aggregate metrics fall in
 * their valid ranges.
 *
 * <p>The stub always assigns an empty {@code BytesRef} with a score obtained by applying the
 * logistic function to a random int; both {@code getClasses} variants return {@code null}.
 */
@Test
public void testGetConfusionMatrix() throws Exception {
  MockAnalyzer analyzer = new MockAnalyzer(random());
  // try-with-resources closes the reader without letting a failure in close() replace an
  // assertion failure raised in the body (it is attached as a suppressed exception instead).
  try (LeafReader reader = getSampleIndex(analyzer)) {
    Classifier<BytesRef> classifier = new Classifier<BytesRef>() {
      @Override
      public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
        return new ClassificationResult<>(new BytesRef(), 1 / (1 + Math.exp(-random().nextInt())));
      }

      @Override
      public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
        return null;
      }

      @Override
      public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
        return null;
      }
    };

    ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix =
        ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName, -1);
    assertNotNull(confusionMatrix);
    assertNotNull(confusionMatrix.getLinearizedMatrix());
    assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());

    double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
    assertTrue(avgClassificationTime >= 0d);

    // Each of the remaining metrics must lie within [0, 1].
    double accuracy = confusionMatrix.getAccuracy();
    assertTrue(accuracy >= 0d);
    assertTrue(accuracy <= 1d);

    double precision = confusionMatrix.getPrecision();
    assertTrue(precision >= 0d);
    assertTrue(precision <= 1d);

    double recall = confusionMatrix.getRecall();
    assertTrue(recall >= 0d);
    assertTrue(recall <= 1d);

    double f1Measure = confusionMatrix.getF1Measure();
    assertTrue(f1Measure >= 0d);
    assertTrue(f1Measure <= 1d);
  }
}
use of org.apache.lucene.classification.ClassificationResult in project lucene-solr by apache.
the class ClassificationUpdateProcessor method processAdd.
/**
 * Adds predicted classes to the incoming document when it carries no value in
 * {@code trainingClassField}, then delegates to the superclass implementation.
 *
 * <p>When the field is absent, the classifier is asked for up to {@code maxOutputClasses}
 * classes for the Lucene form of the document; each returned class is decoded as UTF-8 and
 * added as a value of {@code predictedClassField}. A {@code null} classifier result adds
 * nothing.
 *
 * @param cmd the update command in input containing the Document to classify
 * @throws IOException If there is a low-level I/O error
 */
@Override
public void processAdd(AddUpdateCommand cmd) throws IOException {
  SolrInputDocument solrDocument = cmd.getSolrInputDocument();
  Document documentToClassify = cmd.getLuceneDocument();

  boolean alreadyLabelled = solrDocument.getFieldValue(trainingClassField) != null;
  if (!alreadyLabelled) {
    List<ClassificationResult<BytesRef>> predictions = classifier.getClasses(documentToClassify, maxOutputClasses);
    if (predictions != null) {
      for (ClassificationResult<BytesRef> prediction : predictions) {
        BytesRef predictedClass = prediction.getAssignedClass();
        solrDocument.addField(predictedClassField, predictedClass.utf8ToString());
      }
    }
  }

  super.processAdd(cmd);
}
Aggregations