Usage of org.apache.stanbol.enhancer.topic.api.training.Example in the Apache Stanbol project.
Class TopicClassificationEngine, method performCVFold:
/**
 * Performs one iteration of k-fold cross validation: builds a dedicated evaluation classifier
 * trained on the train folds and computes precision / recall statistics for each topic on the
 * test fold, storing the results as performance metadata.
 *
 * @param cvFoldIndex zero-based index of the fold to use as the test fold for this iteration
 * @param cvFoldCount total number of folds
 * @param cvIterations number of iterations reported in log messages; values <= 0 default to cvFoldCount
 * @param incremental NOTE(review): not referenced anywhere in this method body — confirm intent
 * @return the number of topics whose performance metadata was updated
 */
protected int performCVFold(int cvFoldIndex, int cvFoldCount, int cvIterations, boolean incremental) throws ConfigurationException,
        TrainingSetException, ClassifierException {
    // BUG FIX: the original ternary read "cvIterations <= 0 ? cvFoldCount : cvFoldCount" which
    // returns cvFoldCount in both branches, silently discarding any explicitly requested
    // iteration count. Preserve the caller-provided value when it is positive.
    cvIterations = cvIterations <= 0 ? cvFoldCount : cvIterations;
    log.info(String.format("Performing evaluation %d-fold CV iteration %d/%d on classifier %s", cvFoldCount,
        cvFoldIndex + 1, cvIterations, engineName));
    long start = System.currentTimeMillis();
    final TopicClassificationEngine classifier = new TopicClassificationEngine();
    try {
        if (managedSolrServer != null) {
            // OSGi setup: the evaluation server will be generated automatically using the
            // managedSolrServer
            classifier.bindManagedSolrServer(managedSolrServer);
            classifier.activate(context, getCanonicalConfiguration(// TODO: maybe we should use the SolrCoreName instead
                engineName + "-evaluation", solrCoreConfig));
        } else {
            // non-OSGi setup: lazily create an embedded Solr server for evaluation
            if (__evaluationServer == null) {
                __evaluationServerDir = new File(embeddedSolrServerDir, engineName + "-evaluation");
                if (!__evaluationServerDir.exists()) {
                    FileUtils.forceMkdir(__evaluationServerDir);
                }
                __evaluationServer = EmbeddedSolrHelper.makeEmbeddedSolrServer(__evaluationServerDir,
                    "evaluationclassifierserver", "default-topic-model", "default-topic-model");
            }
            classifier.configure(getCanonicalConfiguration(__evaluationServer, solrCoreConfig));
        }
    } catch (Exception e) {
        throw new ClassifierException(e);
    }
    // clean all previous concepts from the evaluation classifier in case we are reusing an existing solr
    // index from OSGi.
    classifier.removeAllConcepts();
    // iterate over all the topics to register them in the evaluation classifier
    batchOverTopics(new BatchProcessor<SolrDocument>() {
        @Override
        public int process(List<SolrDocument> batch) throws ClassifierException {
            for (SolrDocument topicEntry : batch) {
                String conceptId = topicEntry.getFirstValue(conceptUriField).toString();
                Collection<Object> broader = topicEntry.getFieldValues(broaderField);
                if (broader == null) {
                    classifier.addConcept(conceptId, null, null);
                } else {
                    List<String> broaderConcepts = new ArrayList<String>();
                    for (Object broaderConcept : broader) {
                        broaderConcepts.add(broaderConcept.toString());
                    }
                    classifier.addConcept(conceptId, null, broaderConcepts);
                }
            }
            return batch.size();
        }
    });
    // build the model for the current train CV folds
    classifier.setCrossValidationInfo(cvFoldIndex, cvFoldCount);
    // bind our new classifier to the same training set as the parent
    classifier.setTrainingSet(getTrainingSet());
    classifier.updateModel(false);
    final int foldCount = cvFoldCount;
    final int foldIndex = cvFoldIndex;
    // iterate over the topics again to compute scores on the test fold
    int updatedTopics = batchOverTopics(new BatchProcessor<SolrDocument>() {
        @Override
        public int process(List<SolrDocument> batch) throws TrainingSetException, ClassifierException {
            int offset;
            int updated = 0;
            for (SolrDocument topicMetadata : batch) {
                String topic = topicMetadata.getFirstValue(conceptUriField).toString();
                List<String> topics = Arrays.asList(topic);
                List<String> falseNegativeExamples = new ArrayList<String>();
                int truePositives = 0;
                int falseNegatives = 0;
                int positiveSupport = 0;
                offset = 0;
                Batch<Example> examples = Batch.emtpyBatch(Example.class);
                boolean skipTopic = false;
                // first pass: positive examples of the topic restricted to the test fold
                do {
                    examples = getTrainingSet().getPositiveExamples(topics, examples.nextOffset);
                    if (offset == 0 && examples.items.size() < MIN_EVALUATION_SAMPLES) {
                        // we need a minimum amount of examples otherwise it's really not
                        // worth computing statistics
                        skipTopic = true;
                        break;
                    }
                    for (Example example : examples.items) {
                        if (!(offset % foldCount == foldIndex)) {
                            // this example is not part of the test fold, skip it
                            offset++;
                            continue;
                        }
                        positiveSupport++;
                        offset++;
                        List<TopicSuggestion> suggestedTopics = classifier.suggestTopics(example.contents);
                        boolean match = false;
                        for (TopicSuggestion suggestedTopic : suggestedTopics) {
                            if (topic.equals(suggestedTopic.conceptUri)) {
                                match = true;
                                truePositives++;
                                break;
                            }
                        }
                        if (!match) {
                            falseNegatives++;
                            if (falseNegativeExamples.size() < MAX_COLLECTED_EXAMPLES / foldCount) {
                                falseNegativeExamples.add(example.id);
                            }
                        }
                    }
                } while (!skipTopic && examples.hasMore && offset < MAX_EVALUATION_SAMPLES);
                List<String> falsePositiveExamples = new ArrayList<String>();
                int falsePositives = 0;
                int negativeSupport = 0;
                offset = 0;
                examples = Batch.emtpyBatch(Example.class);
                // second pass: negative examples restricted to the test fold
                do {
                    if (skipTopic) {
                        break;
                    }
                    examples = getTrainingSet().getNegativeExamples(topics, examples.nextOffset);
                    for (Example example : examples.items) {
                        if (!(offset % foldCount == foldIndex)) {
                            // this example is not part of the test fold, skip it
                            offset++;
                            continue;
                        }
                        negativeSupport++;
                        offset++;
                        List<TopicSuggestion> suggestedTopics = classifier.suggestTopics(example.contents);
                        for (TopicSuggestion suggestedTopic : suggestedTopics) {
                            if (topic.equals(suggestedTopic.conceptUri)) {
                                falsePositives++;
                                if (falsePositiveExamples.size() < MAX_COLLECTED_EXAMPLES / foldCount) {
                                    falsePositiveExamples.add(example.id);
                                }
                                break;
                            }
                        }
                        // we don't need to collect true negatives
                    }
                } while (examples.hasMore && offset < MAX_EVALUATION_SAMPLES);
                if (skipTopic) {
                    log.debug("Skipping evaluation of {} because too few positive examples.", topic);
                } else {
                    // compute precision, recall and f1 score for the current test fold and topic
                    float precision = 0;
                    if (truePositives != 0 || falsePositives != 0) {
                        precision = truePositives / (float) (truePositives + falsePositives);
                    }
                    float recall = 0;
                    if (truePositives != 0 || falseNegatives != 0) {
                        recall = truePositives / (float) (truePositives + falseNegatives);
                    }
                    updatePerformanceMetadata(topic, precision, recall, positiveSupport, negativeSupport,
                        falsePositiveExamples, falseNegativeExamples);
                    updated += 1;
                }
            }
            try {
                getActiveSolrServer().commit();
            } catch (Exception e) {
                throw new ClassifierException(e);
            }
            return updated;
        }
    });
    long stop = System.currentTimeMillis();
    log.info(String.format("Finished CV iteration %d/%d on classifier %s in %fs.", cvFoldIndex + 1, cvFoldCount,
        engineName, (stop - start) / 1000.0));
    if (context != null) {
        // close open trackers
        classifier.deactivate(context);
    }
    return updatedTopics;
}
Usage of org.apache.stanbol.enhancer.topic.api.training.Example in the Apache Stanbol project.
Class SolrTrainingSet, method getExamples:
/**
 * Fetches one batch of training examples from the Solr index.
 *
 * @param topics the topic URIs to match; an empty list matches all examples
 * @param offset pagination cursor: the example id the next batch starts at, or null for the first batch
 * @param positive when true, return examples labeled with any of the topics; when false, return
 *            examples labeled with none of them
 * @return a batch of at most batchSize examples together with the cursor for the next batch
 * @throws TrainingSetException if the Solr query fails
 */
protected Batch<Example> getExamples(List<String> topics, Object offset, boolean positive) throws TrainingSetException {
    List<Example> items = new ArrayList<Example>();
    SolrServer solrServer = getActiveSolrServer();
    SolrQuery query = new SolrQuery();
    List<String> parts = new ArrayList<String>();
    String q = "";
    if (topics.isEmpty()) {
        q += "*:*";
    } else if (positive) {
        for (String topic : topics) {
            parts.add(topicUrisField + ":" + ClientUtils.escapeQueryChars(topic));
        }
        // wrap the OR disjunction in parentheses so the AND offset clause appended below
        // applies to the whole disjunction and not only to its last term
        if (offset != null) {
            q += "(";
        }
        q += StringUtils.join(parts, " OR ");
        if (offset != null) {
            q += ")";
        }
    } else {
        // negative query: exclude every topic; no parentheses needed since all clauses are ANDed
        for (String topic : topics) {
            parts.add("-" + topicUrisField + ":" + ClientUtils.escapeQueryChars(topic));
        }
        q += StringUtils.join(parts, " AND ");
    }
    if (offset != null) {
        q += " AND " + exampleIdField + ":[" + offset.toString() + " TO *]";
    }
    query.setQuery(q);
    query.addSortField(exampleIdField, SolrQuery.ORDER.asc);
    // fetch one extra row to detect whether a further batch exists
    query.set("rows", batchSize + 1);
    String nextExampleId = null;
    try {
        int count = 0;
        QueryResponse response = solrServer.query(query);
        for (SolrDocument result : response.getResults()) {
            if (count == batchSize) {
                // the extra row is not returned; its id becomes the cursor of the next batch
                nextExampleId = result.getFirstValue(exampleIdField).toString();
            } else {
                count++;
                String exampleId = result.getFirstValue(exampleIdField).toString();
                Collection<Object> labelValues = result.getFieldValues(topicUrisField);
                Collection<Object> textValues = result.getFieldValues(exampleTextField);
                if (textValues == null) {
                    // skip examples that carry no text content
                    continue;
                }
                items.add(new Example(exampleId, labelValues, textValues));
            }
        }
    } catch (SolrServerException e) {
        // BUG FIX: the original message always said "positive examples" even when this method
        // was invoked for negative examples; report the actual query kind.
        String msg = String.format("Error while fetching %s examples for topics ['%s'] on Solr Core '%s'.",
            positive ? "positive" : "negative", StringUtils.join(topics, "', '"), solrCoreId);
        throw new TrainingSetException(msg, e);
    }
    return new Batch<Example>(items, nextExampleId != null, nextExampleId);
}
Usage of org.apache.stanbol.enhancer.topic.api.training.Example in the Apache Stanbol project.
Class TrainingSetTest, method testBatchingPositiveExamples:
/**
 * Registers 28 positive examples for TOPIC_1 and verifies that, with a batch size of 10,
 * they come back in three batches of 10, 10 and 8 items and that pagination covers every
 * registered id and text exactly.
 */
@Test
public void testBatchingPositiveExamples() throws ConfigurationException, TrainingSetException {
    log.info(" --- testBatchingPositiveExamples --- ");
    Set<String> registeredIds = new HashSet<String>();
    Set<String> registeredTexts = new HashSet<String>();
    for (int i = 0; i < 28; i++) {
        String exampleId = "example-" + i;
        String exampleText = "Text of example" + i + ".";
        trainingSet.registerExample(exampleId, exampleText, Arrays.asList(TOPIC_1));
        registeredIds.add(exampleId);
        registeredTexts.add(exampleText);
    }
    trainingSet.setBatchSize(10);
    Set<String> fetchedIds = new HashSet<String>();
    Set<String> fetchedTexts = new HashSet<String>();
    // 28 examples paginated by 10 yield batches of 10, 10 and 8
    int[] expectedBatchSizes = {10, 10, 8};
    Batch<Example> examples = null;
    for (int batchIndex = 0; batchIndex < expectedBatchSizes.length; batchIndex++) {
        Object cursor = examples == null ? null : examples.nextOffset;
        examples = trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1, TOPIC_2), cursor);
        assertEquals(expectedBatchSizes[batchIndex], examples.items.size());
        for (Example example : examples.items) {
            fetchedIds.add(example.id);
            fetchedTexts.add(example.getContentString());
        }
        // only the final batch exhausts the result set
        boolean lastBatch = batchIndex == expectedBatchSizes.length - 1;
        assertEquals(!lastBatch, examples.hasMore);
    }
    assertEquals(registeredIds, fetchedIds);
    assertEquals(registeredTexts, fetchedTexts);
}
Usage of org.apache.stanbol.enhancer.topic.api.training.Example in the Apache Stanbol project.
Class TrainingSetTest, method testBatchingNegativeExamplesAndAutoId:
/**
 * Registers 17 examples for TOPIC_1 with auto-generated ids and verifies that querying them
 * as negative examples of TOPIC_2 (with a batch size of 10) yields batches of 10 and 7 items
 * covering every auto-generated id and every text exactly.
 */
@Test
public void testBatchingNegativeExamplesAndAutoId() throws ConfigurationException, TrainingSetException {
    log.info(" --- testBatchingNegativeExamplesAndAutoId --- ");
    Set<String> registeredIds = new HashSet<String>();
    Set<String> registeredTexts = new HashSet<String>();
    for (int i = 0; i < 17; i++) {
        String exampleText = "Text of example" + i + ".";
        // passing null as id lets the training set generate one automatically
        String generatedId = trainingSet.registerExample(null, exampleText, Arrays.asList(TOPIC_1));
        registeredIds.add(generatedId);
        registeredTexts.add(exampleText);
    }
    trainingSet.setBatchSize(10);
    Set<String> fetchedIds = new HashSet<String>();
    Set<String> fetchedTexts = new HashSet<String>();
    // 17 examples paginated by 10 yield batches of 10 and 7
    int[] expectedBatchSizes = {10, 7};
    Batch<Example> examples = null;
    for (int batchIndex = 0; batchIndex < expectedBatchSizes.length; batchIndex++) {
        Object cursor = examples == null ? null : examples.nextOffset;
        examples = trainingSet.getNegativeExamples(Arrays.asList(TOPIC_2), cursor);
        assertEquals(expectedBatchSizes[batchIndex], examples.items.size());
        for (Example example : examples.items) {
            fetchedIds.add(example.id);
            fetchedTexts.add(example.getContentString());
        }
        // only the final batch exhausts the result set
        boolean lastBatch = batchIndex == expectedBatchSizes.length - 1;
        assertEquals(!lastBatch, examples.hasMore);
    }
    assertEquals(registeredIds, fetchedIds);
    assertEquals(registeredTexts, fetchedTexts);
}
Usage of org.apache.stanbol.enhancer.topic.api.training.Example in the Apache Stanbol project.
Class TopicClassificationEngine, method updateTopic:
/**
 * Re-indexes the model entry and the metadata entry of a topic, rebuilding the similarity text
 * from the positive examples of the impacted topics.
 *
 * @param conceptUri
 *            the URI of the topic (concept) whose model is being updated
 * @param metadataId
 *            the id of the metadata entry of the topic
 * @param modelId
 *            the id of the model entry of the topic
 * @param impactedTopics
 *            the list of impacted topics (e.g. the topic node and direct children)
 * @param primaryTopicUri
 *            optional primary topic URI stored in the metadata entry
 * @param broaderConcepts
 *            the collection of broader concepts to re-add in the broader field
 */
protected void updateTopic(String conceptUri, String metadataId, String modelId, List<String> impactedTopics,
        String primaryTopicUri, Collection<Object> broaderConcepts) throws TrainingSetException, ClassifierException {
    long start = System.currentTimeMillis();
    Batch<Example> examples = Batch.emtpyBatch(Example.class);
    // StringBuilder instead of StringBuffer: the buffer is method-local, so the
    // unsynchronized variant is the idiomatic choice
    StringBuilder sb = new StringBuilder();
    int offset = 0;
    // accumulate example text up to MAX_CHARS_PER_TOPIC, paginating over the training set
    do {
        examples = getTrainingSet().getPositiveExamples(impactedTopics, examples.nextOffset);
        for (Example example : examples.items) {
            if ((cvFoldCount != 0) && (offset % cvFoldCount == cvFoldIndex)) {
                // we are performing a cross validation session and this example belongs to the test
                // fold hence should be skipped
                offset++;
                continue;
            }
            offset++;
            sb.append(StringUtils.join(example.contents, "\n\n"));
            sb.append("\n\n");
        }
    } while (sb.length() < MAX_CHARS_PER_TOPIC && examples.hasMore);
    // reindex the topic with the new text data collected from the examples
    SolrInputDocument modelEntry = new SolrInputDocument();
    modelEntry.addField(entryIdField, modelId);
    modelEntry.addField(conceptUriField, conceptUri);
    modelEntry.addField(entryTypeField, MODEL_ENTRY);
    if (sb.length() > 0) {
        modelEntry.addField(similarityField, sb.toString());
    }
    // update the metadata of the topic model
    SolrInputDocument metadataEntry = new SolrInputDocument();
    metadataEntry.addField(entryIdField, metadataId);
    metadataEntry.addField(modelEntryIdField, modelId);
    metadataEntry.addField(entryTypeField, METADATA_ENTRY);
    metadataEntry.addField(conceptUriField, conceptUri);
    if (primaryTopicUriField != null) {
        metadataEntry.addField(primaryTopicUriField, primaryTopicUri);
    }
    if (broaderConcepts != null && broaderField != null) {
        metadataEntry.addField(broaderField, broaderConcepts);
    }
    if (modelUpdateDateField != null) {
        metadataEntry.addField(modelUpdateDateField, UTCTimeStamper.nowUtcDate());
    }
    SolrServer solrServer = getActiveSolrServer();
    try {
        UpdateRequest request = new UpdateRequest();
        request.add(metadataEntry);
        request.add(modelEntry);
        solrServer.request(request);
        // the commit is done by the caller in batch
    } catch (Exception e) {
        String msg = String.format("Error updating topic with id '%s' on Solr Core '%s'", conceptUri, solrCoreId);
        throw new ClassifierException(msg, e);
    }
    long stop = System.currentTimeMillis();
    log.debug("Successfully updated topic {} in {}s", conceptUri, (double) (stop - start) / 1000.);
}
Aggregations