Example usage of org.apache.stanbol.enhancer.topic.api.ClassifierException in the Apache Stanbol project: the testCrossValidation method of the TopicEngineTest class.
@Test
public void testCrossValidation() throws Exception {
log.info(" --- testCrossValidation --- ");
// seed a pseudo random number generator for reproducible tests
Random rng = new Random(0);
ClassificationReport performanceEstimates;
// build an artificial data set used for training models and evaluation
int numberOfTopics = 10;
int numberOfDocuments = 100;
int vocabSizeMin = 20;
int vocabSizeMax = 30;
initArtificialTrainingSet(numberOfTopics, numberOfDocuments, vocabSizeMin, vocabSizeMax, rng);
// by default the reports are not computed
performanceEstimates = classifier.getPerformanceEstimates("urn:t/001");
assertFalse(performanceEstimates.uptodate);
performanceEstimates = classifier.getPerformanceEstimates("urn:t/002");
assertFalse(performanceEstimates.uptodate);
performanceEstimates = classifier.getPerformanceEstimates("urn:t/003");
assertFalse(performanceEstimates.uptodate);
try {
classifier.getPerformanceEstimates("urn:doesnotexist");
fail("Should have raised a ClassifierException");
} catch (ClassifierException e) {
// expected
}
// launch an evaluation of the classifier according to the current state of the training set
assertEquals(numberOfTopics, classifier.updatePerformanceEstimates(true));
for (int i = 1; i <= numberOfTopics; i++) {
String topic = String.format("urn:t/%03d", i);
performanceEstimates = classifier.getPerformanceEstimates(topic);
assertTrue(performanceEstimates.uptodate);
assertGreater(performanceEstimates.precision, 0.45f);
assertNotNull(performanceEstimates.falsePositiveExampleIds);
assertNotNull(performanceEstimates.falseNegativeExampleIds);
if (performanceEstimates.precision < 1) {
assertFalse(performanceEstimates.falsePositiveExampleIds.isEmpty());
}
if (performanceEstimates.recall < 1) {
assertFalse(performanceEstimates.falseNegativeExampleIds.isEmpty());
}
assertGreater(performanceEstimates.recall, 0.45f);
assertGreater(performanceEstimates.f1, 0.55f);
// very small support, hence the estimates are unstable, hence we set low min expectations, but we
// need this test to run reasonably fast...
assertGreater(performanceEstimates.positiveSupport, 4);
assertGreater(performanceEstimates.negativeSupport, 4);
assertNotNull(performanceEstimates.evaluationDate);
}
// TODO: test model invalidation by registering a sub topic manually
}
Example usage of org.apache.stanbol.enhancer.topic.api.ClassifierException in the Apache Stanbol project: the performCVFold method of the TopicClassificationEngine class.
protected int performCVFold(int cvFoldIndex, int cvFoldCount, int cvIterations, boolean incremental) throws ConfigurationException, TrainingSetException, ClassifierException {
cvIterations = cvIterations <= 0 ? cvFoldCount : cvFoldCount;
log.info(String.format("Performing evaluation %d-fold CV iteration %d/%d on classifier %s", cvFoldCount, cvFoldIndex + 1, cvIterations, engineName));
long start = System.currentTimeMillis();
final TopicClassificationEngine classifier = new TopicClassificationEngine();
try {
if (managedSolrServer != null) {
// OSGi setup: the evaluation server will be generated automatically using the
// managedSolrServer
classifier.bindManagedSolrServer(managedSolrServer);
classifier.activate(context, getCanonicalConfiguration(//TODO: maybe we should use the SolrCoreName instead
engineName + "-evaluation", solrCoreConfig));
} else {
if (__evaluationServer == null) {
__evaluationServerDir = new File(embeddedSolrServerDir, engineName + "-evaluation");
if (!__evaluationServerDir.exists()) {
FileUtils.forceMkdir(__evaluationServerDir);
}
__evaluationServer = EmbeddedSolrHelper.makeEmbeddedSolrServer(__evaluationServerDir, "evaluationclassifierserver", "default-topic-model", "default-topic-model");
}
classifier.configure(getCanonicalConfiguration(__evaluationServer, solrCoreConfig));
}
} catch (Exception e) {
throw new ClassifierException(e);
}
// clean all previous concepts from the evaluation classifier in case we are reusing an existing solr
// index from OSGi.
classifier.removeAllConcepts();
// iterate over all the topics to register them in the evaluation classifier
batchOverTopics(new BatchProcessor<SolrDocument>() {
@Override
public int process(List<SolrDocument> batch) throws ClassifierException {
for (SolrDocument topicEntry : batch) {
String conceptId = topicEntry.getFirstValue(conceptUriField).toString();
Collection<Object> broader = topicEntry.getFieldValues(broaderField);
if (broader == null) {
classifier.addConcept(conceptId, null, null);
} else {
List<String> broaderConcepts = new ArrayList<String>();
for (Object broaderConcept : broader) {
broaderConcepts.add(broaderConcept.toString());
}
classifier.addConcept(conceptId, null, broaderConcepts);
}
}
return batch.size();
}
});
// build the model on the for the current train CV folds
classifier.setCrossValidationInfo(cvFoldIndex, cvFoldCount);
// bind our new classifier to the same training set at the parent
classifier.setTrainingSet(getTrainingSet());
classifier.updateModel(false);
final int foldCount = cvFoldCount;
final int foldIndex = cvFoldIndex;
// iterate over the topics again to compute scores on the test fold
int updatedTopics = batchOverTopics(new BatchProcessor<SolrDocument>() {
@Override
public int process(List<SolrDocument> batch) throws TrainingSetException, ClassifierException {
int offset;
int updated = 0;
for (SolrDocument topicMetadata : batch) {
String topic = topicMetadata.getFirstValue(conceptUriField).toString();
List<String> topics = Arrays.asList(topic);
List<String> falseNegativeExamples = new ArrayList<String>();
int truePositives = 0;
int falseNegatives = 0;
int positiveSupport = 0;
offset = 0;
Batch<Example> examples = Batch.emtpyBatch(Example.class);
boolean skipTopic = false;
do {
examples = getTrainingSet().getPositiveExamples(topics, examples.nextOffset);
if (offset == 0 && examples.items.size() < MIN_EVALUATION_SAMPLES) {
// we need a minimum about of examples otherwise it's really not
// worth computing statistics
skipTopic = true;
break;
}
for (Example example : examples.items) {
if (!(offset % foldCount == foldIndex)) {
// this example is not part of the test fold, skip it
offset++;
continue;
}
positiveSupport++;
offset++;
List<TopicSuggestion> suggestedTopics = classifier.suggestTopics(example.contents);
boolean match = false;
for (TopicSuggestion suggestedTopic : suggestedTopics) {
if (topic.equals(suggestedTopic.conceptUri)) {
match = true;
truePositives++;
break;
}
}
if (!match) {
falseNegatives++;
if (falseNegativeExamples.size() < MAX_COLLECTED_EXAMPLES / foldCount) {
falseNegativeExamples.add(example.id);
}
}
}
} while (!skipTopic && examples.hasMore && offset < MAX_EVALUATION_SAMPLES);
List<String> falsePositiveExamples = new ArrayList<String>();
int falsePositives = 0;
int negativeSupport = 0;
offset = 0;
examples = Batch.emtpyBatch(Example.class);
do {
if (skipTopic) {
break;
}
examples = getTrainingSet().getNegativeExamples(topics, examples.nextOffset);
for (Example example : examples.items) {
if (!(offset % foldCount == foldIndex)) {
// this example is not part of the test fold, skip it
offset++;
continue;
}
negativeSupport++;
offset++;
List<TopicSuggestion> suggestedTopics = classifier.suggestTopics(example.contents);
for (TopicSuggestion suggestedTopic : suggestedTopics) {
if (topic.equals(suggestedTopic.conceptUri)) {
falsePositives++;
if (falsePositiveExamples.size() < MAX_COLLECTED_EXAMPLES / foldCount) {
falsePositiveExamples.add(example.id);
}
break;
}
}
// we don't need to collect true negatives
}
} while (examples.hasMore && offset < MAX_EVALUATION_SAMPLES);
if (skipTopic) {
log.debug("Skipping evaluation of {} because too few positive examples.", topic);
} else {
// compute precision, recall and f1 score for the current test fold and topic
float precision = 0;
if (truePositives != 0 || falsePositives != 0) {
precision = truePositives / (float) (truePositives + falsePositives);
}
float recall = 0;
if (truePositives != 0 || falseNegatives != 0) {
recall = truePositives / (float) (truePositives + falseNegatives);
}
updatePerformanceMetadata(topic, precision, recall, positiveSupport, negativeSupport, falsePositiveExamples, falseNegativeExamples);
updated += 1;
}
}
try {
getActiveSolrServer().commit();
} catch (Exception e) {
throw new ClassifierException(e);
}
return updated;
}
});
long stop = System.currentTimeMillis();
log.info(String.format("Finished CV iteration %d/%d on classifier %s in %fs.", cvFoldIndex + 1, cvFoldCount, engineName, (stop - start) / 1000.0));
if (context != null) {
// close open trackers
classifier.deactivate(context);
}
return updatedTopics;
}
Example usage of org.apache.stanbol.enhancer.topic.api.ClassifierException in the Apache Stanbol project: the removeAllConcepts method of the TopicClassificationEngine class.
@Override
public void removeAllConcepts() throws ClassifierException {
SolrServer solrServer = getActiveSolrServer();
try {
solrServer.deleteByQuery("*:*");
solrServer.commit();
} catch (Exception e) {
String msg = String.format("Error deleting concepts from Solr Core '%s'", solrCoreId);
throw new ClassifierException(msg, e);
}
}
Example usage of org.apache.stanbol.enhancer.topic.api.ClassifierException in the Apache Stanbol project: the suggestTopics method of the TopicClassificationEngine class.
public List<TopicSuggestion> suggestTopics(String text) throws ClassifierException {
List<TopicSuggestion> suggestedTopics = new ArrayList<TopicSuggestion>(MAX_SUGGESTIONS * 3);
SolrServer solrServer = getActiveSolrServer();
SolrQuery query = new SolrQuery();
query.setRequestHandler("/" + MoreLikeThisParams.MLT);
query.setFilterQueries(entryTypeField + ":" + MODEL_ENTRY);
query.set(MoreLikeThisParams.MATCH_INCLUDE, false);
query.set(MoreLikeThisParams.MIN_DOC_FREQ, 1);
query.set(MoreLikeThisParams.MIN_TERM_FREQ, 1);
query.set(MoreLikeThisParams.MAX_QUERY_TERMS, 30);
query.set(MoreLikeThisParams.MAX_NUM_TOKENS_PARSED, 10000);
// TODO: find a way to parse the interesting terms and report them
// for debugging / explanation in dedicated RDF data structure.
// query.set(MoreLikeThisParams.INTERESTING_TERMS, "details");
query.set(MoreLikeThisParams.SIMILARITY_FIELDS, similarityField);
query.set(CommonParams.STREAM_BODY, text);
// over query the number of suggestions to find a statistical cut based on the curve of the scores of
// the top suggestion
query.setRows(MAX_SUGGESTIONS * 3);
query.setFields(conceptUriField);
query.setIncludeScore(true);
try {
StreamQueryRequest request = new StreamQueryRequest(query);
QueryResponse response = request.process(solrServer);
SolrDocumentList results = response.getResults();
for (SolrDocument result : results.toArray(new SolrDocument[0])) {
String conceptUri = (String) result.getFirstValue(conceptUriField);
if (conceptUri == null) {
throw new ClassifierException(String.format("Solr Core '%s' is missing required field '%s'.", solrCoreId, conceptUriField));
}
Float score = (Float) result.getFirstValue("score");
// fetch metadata
SolrQuery metadataQuery = new SolrQuery("*:*");
// use filter queries to leverage the Solr cache explicitly
metadataQuery.addFilterQuery(entryTypeField + ":" + METADATA_ENTRY);
metadataQuery.addFilterQuery(conceptUriField + ":" + ClientUtils.escapeQueryChars(conceptUri));
metadataQuery.setFields(conceptUriField, broaderField, primaryTopicUriField);
SolrDocument metadata = solrServer.query(metadataQuery).getResults().get(0);
String primaryTopicUri = (String) metadata.getFirstValue(primaryTopicUriField);
suggestedTopics.add(new TopicSuggestion(conceptUri, primaryTopicUri, metadata.getFieldValues(broaderField), score));
}
} catch (SolrServerException e) {
if ("unknown handler: /mlt".equals(e.getCause().getMessage())) {
String message = String.format("SolrServer with id '%s' for topic engine '%s' lacks" + " configuration for the MoreLikeThisHandler", solrCoreId, engineName);
throw new ClassifierException(message, e);
} else {
throw new ClassifierException(e);
}
}
if (suggestedTopics.size() <= 1) {
// no need to apply the cutting heuristic
return suggestedTopics;
}
// filter out suggestions that are less than some threshold based on the mean of the top scores
float mean = 0.0f;
for (TopicSuggestion suggestion : suggestedTopics) {
mean += suggestion.score / suggestedTopics.size();
}
float threshold = 0.25f * suggestedTopics.get(0).score + 0.75f * mean;
List<TopicSuggestion> filteredSuggestions = new ArrayList<TopicSuggestion>();
for (TopicSuggestion suggestion : suggestedTopics) {
if (filteredSuggestions.size() >= MAX_SUGGESTIONS) {
return filteredSuggestions;
}
if (filteredSuggestions.isEmpty() || suggestion.score > threshold) {
filteredSuggestions.add(suggestion);
} else {
break;
}
}
return filteredSuggestions;
}
Example usage of org.apache.stanbol.enhancer.topic.api.ClassifierException in the Apache Stanbol project: the invalidateModelFields method of the TopicClassificationEngine class.
/*
* The commit is the responsibility of the caller.
*/
protected void invalidateModelFields(Collection<String> conceptIds, String... fieldNames) throws ClassifierException {
if (conceptIds.isEmpty() || fieldNames.length == 0) {
return;
}
SolrServer solrServer = getActiveSolrServer();
List<String> invalidatedFields = Arrays.asList(fieldNames);
try {
UpdateRequest request = new UpdateRequest();
for (String conceptId : conceptIds) {
SolrQuery query = new SolrQuery("*:*");
query.addFilterQuery(entryTypeField + ":" + METADATA_ENTRY);
query.addFilterQuery(conceptUriField + ":" + ClientUtils.escapeQueryChars(conceptId));
for (SolrDocument result : solrServer.query(query).getResults()) {
// there should be only one (or none: tolerated)
SolrInputDocument newEntry = new SolrInputDocument();
for (String fieldName : result.getFieldNames()) {
if (!invalidatedFields.contains(fieldName)) {
newEntry.setField(fieldName, result.getFieldValues(fieldName));
}
}
request.add(newEntry);
}
}
if (request.getDocuments() != null && request.getDocuments().size() > 0) {
solrServer.request(request);
}
} catch (Exception e) {
String msg = String.format("Error invalidating topics [%s] on Solr Core '%s'", StringUtils.join(conceptIds, ", "), solrCoreId);
throw new ClassifierException(msg, e);
}
}
Aggregations