Use of org.apache.lucene.classification.Classifier in the lucene-solr project by Apache.
Method testGetConfusionMatrix of the class ConfusionMatrixGeneratorTest.
/**
 * Builds a confusion matrix over the sample index using a stub classifier and checks that
 * the resulting matrix is non-null, reports the expected number of evaluated documents, and
 * exposes metrics that lie within their valid ranges.
 *
 * @throws Exception if building the sample index or generating the confusion matrix fails
 */
@Test
public void testGetConfusionMatrix() throws Exception {
  MockAnalyzer analyzer = new MockAnalyzer(random());
  // try-with-resources closes the reader on every exit path and keeps an assertion failure
  // from the body as the primary exception even if close() itself throws.
  try (LeafReader reader = getSampleIndex(analyzer)) {
    // Stub classifier: always assigns an empty BytesRef class; the score is the logistic
    // function of a random int, so it always falls within [0, 1].
    Classifier<BytesRef> classifier = new Classifier<BytesRef>() {
      @Override
      public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
        return new ClassificationResult<>(new BytesRef(), 1 / (1 + Math.exp(-random().nextInt())));
      }

      @Override
      public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
        return null;
      }

      @Override
      public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
        return null;
      }
    };

    // NOTE(review): the trailing -1 presumably disables a limit/timeout in the generator;
    // confirm against ConfusionMatrixGenerator.getConfusionMatrix.
    ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix =
        ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, categoryFieldName, textFieldName, -1);
    assertNotNull(confusionMatrix);
    assertNotNull(confusionMatrix.getLinearizedMatrix());
    // NOTE(review): 7 presumably matches the number of evaluable docs in the sample index.
    assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());

    double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
    assertTrue(avgClassificationTime >= 0d);

    // Each ratio metric must lie within [0, 1].
    double accuracy = confusionMatrix.getAccuracy();
    assertTrue(accuracy >= 0d);
    assertTrue(accuracy <= 1d);

    double precision = confusionMatrix.getPrecision();
    assertTrue(precision >= 0d);
    assertTrue(precision <= 1d);

    double recall = confusionMatrix.getRecall();
    assertTrue(recall >= 0d);
    assertTrue(recall <= 1d);

    double f1Measure = confusionMatrix.getF1Measure();
    assertTrue(f1Measure >= 0d);
    assertTrue(f1Measure <= 1d);
  }
}
Aggregations