Search in sources :

Example 1 with Example

use of org.apache.stanbol.enhancer.topic.api.training.Example in project stanbol by apache.

the class TopicClassificationEngine method performCVFold.

/**
 * Run one cross-validation iteration: train a separate evaluation classifier on the training folds
 * and measure, topic by topic, how it performs on the examples of the held-out test fold.
 * <p>
 * The evaluation classifier is emptied, re-populated with every topic known to this engine (including
 * the broader links), bound to the same training set and trained with the given fold information.
 * For each topic the positive and negative examples whose running position satisfies
 * {@code position % cvFoldCount == cvFoldIndex} are classified, and the resulting precision and recall
 * are stored through {@code updatePerformanceMetadata}. Topics whose first batch of positive examples
 * holds fewer than {@code MIN_EVALUATION_SAMPLES} items are skipped.
 *
 * @param cvFoldIndex
 *            zero-based index of the fold used as test set
 * @param cvFoldCount
 *            number of folds the examples are split into
 * @param cvIterations
 *            number of iterations reported in the progress log; a value lower or equal to zero means
 *            "as many iterations as folds"
 * @param incremental
 *            not read by this method
 * @return the number of topics for which performance metadata was updated
 * @throws ClassifierException
 *             if the evaluation classifier cannot be set up or the Solr commit fails
 */
protected int performCVFold(int cvFoldIndex, int cvFoldCount, int cvIterations, boolean incremental) throws ConfigurationException, TrainingSetException, ClassifierException {
    // keep the caller supplied iteration count, only default it when it is not a positive number
    cvIterations = cvIterations <= 0 ? cvFoldCount : cvIterations;
    log.info(String.format("Performing evaluation %d-fold CV iteration %d/%d on classifier %s", cvFoldCount, cvFoldIndex + 1, cvIterations, engineName));
    long start = System.currentTimeMillis();
    final TopicClassificationEngine classifier = new TopicClassificationEngine();
    try {
        if (managedSolrServer != null) {
            // OSGi setup: the evaluation server will be generated automatically using the
            // managedSolrServer
            classifier.bindManagedSolrServer(managedSolrServer);
            // TODO: maybe we should use the SolrCoreName instead
            classifier.activate(context, getCanonicalConfiguration(engineName + "-evaluation", solrCoreConfig));
        } else {
            // no managed server: lazily create an embedded evaluation server and reuse it afterwards
            if (__evaluationServer == null) {
                __evaluationServerDir = new File(embeddedSolrServerDir, engineName + "-evaluation");
                if (!__evaluationServerDir.exists()) {
                    FileUtils.forceMkdir(__evaluationServerDir);
                }
                __evaluationServer = EmbeddedSolrHelper.makeEmbeddedSolrServer(__evaluationServerDir, "evaluationclassifierserver", "default-topic-model", "default-topic-model");
            }
            classifier.configure(getCanonicalConfiguration(__evaluationServer, solrCoreConfig));
        }
    } catch (Exception e) {
        throw new ClassifierException(e);
    }
    final int foldCount = cvFoldCount;
    final int foldIndex = cvFoldIndex;
    try {
        // clean all previous concepts from the evaluation classifier in case we are reusing an existing
        // solr index from OSGi.
        classifier.removeAllConcepts();
        // iterate over all the topics to register them in the evaluation classifier
        batchOverTopics(new BatchProcessor<SolrDocument>() {

            @Override
            public int process(List<SolrDocument> batch) throws ClassifierException {
                for (SolrDocument topicEntry : batch) {
                    String conceptId = topicEntry.getFirstValue(conceptUriField).toString();
                    Collection<Object> broader = topicEntry.getFieldValues(broaderField);
                    if (broader == null) {
                        classifier.addConcept(conceptId, null, null);
                    } else {
                        List<String> broaderConcepts = new ArrayList<String>();
                        for (Object broaderConcept : broader) {
                            broaderConcepts.add(broaderConcept.toString());
                        }
                        classifier.addConcept(conceptId, null, broaderConcepts);
                    }
                }
                return batch.size();
            }
        });
        // build the model for the current train CV folds
        classifier.setCrossValidationInfo(cvFoldIndex, cvFoldCount);
        // bind our new classifier to the same training set as the parent
        classifier.setTrainingSet(getTrainingSet());
        classifier.updateModel(false);
        // iterate over the topics again to compute scores on the test fold
        int updatedTopics = batchOverTopics(new BatchProcessor<SolrDocument>() {

            @Override
            public int process(List<SolrDocument> batch) throws TrainingSetException, ClassifierException {
                int offset;
                int updated = 0;
                for (SolrDocument topicMetadata : batch) {
                    String topic = topicMetadata.getFirstValue(conceptUriField).toString();
                    List<String> topics = Arrays.asList(topic);

                    // positive examples: count true positives and false negatives
                    List<String> falseNegativeExamples = new ArrayList<String>();
                    int truePositives = 0;
                    int falseNegatives = 0;
                    int positiveSupport = 0;
                    offset = 0;
                    Batch<Example> examples = Batch.emtpyBatch(Example.class);
                    boolean skipTopic = false;
                    do {
                        examples = getTrainingSet().getPositiveExamples(topics, examples.nextOffset);
                        if (offset == 0 && examples.items.size() < MIN_EVALUATION_SAMPLES) {
                            // we need a minimum amount of examples otherwise it's really not
                            // worth computing statistics
                            skipTopic = true;
                            break;
                        }
                        for (Example example : examples.items) {
                            boolean inTestFold = offset % foldCount == foldIndex;
                            offset++;
                            if (!inTestFold) {
                                // this example is not part of the test fold, skip it
                                continue;
                            }
                            positiveSupport++;
                            List<TopicSuggestion> suggestedTopics = classifier.suggestTopics(example.contents);
                            boolean match = false;
                            for (TopicSuggestion suggestedTopic : suggestedTopics) {
                                if (topic.equals(suggestedTopic.conceptUri)) {
                                    match = true;
                                    truePositives++;
                                    break;
                                }
                            }
                            if (!match) {
                                falseNegatives++;
                                if (falseNegativeExamples.size() < MAX_COLLECTED_EXAMPLES / foldCount) {
                                    falseNegativeExamples.add(example.id);
                                }
                            }
                        }
                    } while (!skipTopic && examples.hasMore && offset < MAX_EVALUATION_SAMPLES);

                    // negative examples: count false positives (true negatives are not collected)
                    List<String> falsePositiveExamples = new ArrayList<String>();
                    int falsePositives = 0;
                    int negativeSupport = 0;
                    offset = 0;
                    examples = Batch.emtpyBatch(Example.class);
                    if (!skipTopic) {
                        do {
                            examples = getTrainingSet().getNegativeExamples(topics, examples.nextOffset);
                            for (Example example : examples.items) {
                                boolean inTestFold = offset % foldCount == foldIndex;
                                offset++;
                                if (!inTestFold) {
                                    // this example is not part of the test fold, skip it
                                    continue;
                                }
                                negativeSupport++;
                                List<TopicSuggestion> suggestedTopics = classifier.suggestTopics(example.contents);
                                for (TopicSuggestion suggestedTopic : suggestedTopics) {
                                    if (topic.equals(suggestedTopic.conceptUri)) {
                                        falsePositives++;
                                        if (falsePositiveExamples.size() < MAX_COLLECTED_EXAMPLES / foldCount) {
                                            falsePositiveExamples.add(example.id);
                                        }
                                        break;
                                    }
                                }
                            }
                        } while (examples.hasMore && offset < MAX_EVALUATION_SAMPLES);
                    }

                    if (skipTopic) {
                        log.debug("Skipping evaluation of {} because too few positive examples.", topic);
                    } else {
                        // compute precision and recall for the current test fold and topic; both stay
                        // at 0 when their denominator would be 0
                        float precision = 0;
                        if (truePositives != 0 || falsePositives != 0) {
                            precision = truePositives / (float) (truePositives + falsePositives);
                        }
                        float recall = 0;
                        if (truePositives != 0 || falseNegatives != 0) {
                            recall = truePositives / (float) (truePositives + falseNegatives);
                        }
                        updatePerformanceMetadata(topic, precision, recall, positiveSupport, negativeSupport, falsePositiveExamples, falseNegativeExamples);
                        updated += 1;
                    }
                }
                // make the metadata updates of this batch visible
                try {
                    getActiveSolrServer().commit();
                } catch (Exception e) {
                    throw new ClassifierException(e);
                }
                return updated;
            }
        });
        long stop = System.currentTimeMillis();
        log.info(String.format("Finished CV iteration %d/%d on classifier %s in %fs.", cvFoldIndex + 1, cvIterations, engineName, (stop - start) / 1000.0));
        return updatedTopics;
    } finally {
        if (context != null) {
            // close open trackers, also when the evaluation failed
            classifier.deactivate(context);
        }
    }
}
Also used : TopicSuggestion(org.apache.stanbol.enhancer.topic.api.TopicSuggestion) EngineException(org.apache.stanbol.enhancer.servicesapi.EngineException) SolrServerException(org.apache.solr.client.solrj.SolrServerException) ConfigurationException(org.osgi.service.cm.ConfigurationException) InvalidSyntaxException(org.osgi.framework.InvalidSyntaxException) TrainingSetException(org.apache.stanbol.enhancer.topic.api.training.TrainingSetException) ClassifierException(org.apache.stanbol.enhancer.topic.api.ClassifierException) InvalidContentException(org.apache.stanbol.enhancer.servicesapi.InvalidContentException) EntityhubException(org.apache.stanbol.entityhub.servicesapi.EntityhubException) ChainException(org.apache.stanbol.enhancer.servicesapi.ChainException) IOException(java.io.IOException) SolrDocument(org.apache.solr.common.SolrDocument) Batch(org.apache.stanbol.enhancer.topic.api.Batch) Example(org.apache.stanbol.enhancer.topic.api.training.Example) Collection(java.util.Collection) SolrDocumentList(org.apache.solr.common.SolrDocumentList) List(java.util.List) ArrayList(java.util.ArrayList) File(java.io.File) TrainingSetException(org.apache.stanbol.enhancer.topic.api.training.TrainingSetException) ClassifierException(org.apache.stanbol.enhancer.topic.api.ClassifierException)

Example 2 with Example

use of org.apache.stanbol.enhancer.topic.api.training.Example in project stanbol by apache.

the class SolrTrainingSet method getExamples.

/**
 * Fetch one batch of training examples from the Solr index, ordered by example id.
 * <p>
 * An empty topic list selects every example. Otherwise positive mode selects the examples labelled with
 * at least one of the topics, and negative mode selects the examples labelled with none of them.
 * Paging works by asking Solr for {@code batchSize + 1} rows: the id of the extra row, when there is
 * one, is returned as the offset of the next batch and the range filter on the id field (inclusive
 * lower bound) restarts from that very document on the following call.
 *
 * @param topics
 *            the topic URIs to select or exclude; may be empty
 * @param offset
 *            the {@code nextOffset} of the previous batch, or {@code null} to start from the beginning
 * @param positive
 *            {@code true} for examples of the topics, {@code false} for examples of other topics
 * @return the examples along with the paging information; examples without text are left out
 * @throws TrainingSetException
 *             if the Solr query fails
 */
protected Batch<Example> getExamples(List<String> topics, Object offset, boolean positive) throws TrainingSetException {
    List<Example> items = new ArrayList<Example>();
    SolrServer solrServer = getActiveSolrServer();
    SolrQuery query = new SolrQuery();
    StringBuilder q = new StringBuilder();
    if (topics.isEmpty()) {
        q.append("*:*");
    } else {
        List<String> parts = new ArrayList<String>(topics.size());
        String prefix = positive ? "" : "-";
        for (String topic : topics) {
            parts.add(prefix + topicUrisField + ":" + ClientUtils.escapeQueryChars(topic));
        }
        if (positive) {
            // the OR clauses must be grouped when they are AND-ed with the offset filter below
            if (offset != null) {
                q.append("(");
            }
            q.append(StringUtils.join(parts, " OR "));
            if (offset != null) {
                q.append(")");
            }
        } else {
            q.append(StringUtils.join(parts, " AND "));
        }
    }
    if (offset != null) {
        // NOTE(review): unlike the topic URIs the offset is inserted unescaped; this looks safe only
        // as long as example ids hold no query syntax characters or whitespace - confirm
        q.append(" AND ").append(exampleIdField).append(":[").append(offset.toString()).append(" TO *]");
    }
    query.setQuery(q.toString());
    query.addSortField(exampleIdField, SolrQuery.ORDER.asc);
    // one more row than needed: its id tells whether (and where) a next batch starts
    query.set("rows", batchSize + 1);
    String nextExampleId = null;
    try {
        int count = 0;
        QueryResponse response = solrServer.query(query);
        for (SolrDocument result : response.getResults()) {
            if (count == batchSize) {
                nextExampleId = result.getFirstValue(exampleIdField).toString();
            } else {
                count++;
                String exampleId = result.getFirstValue(exampleIdField).toString();
                Collection<Object> labelValues = result.getFieldValues(topicUrisField);
                Collection<Object> textValues = result.getFieldValues(exampleTextField);
                if (textValues == null) {
                    // nothing to learn from an example without text
                    continue;
                }
                items.add(new Example(exampleId, labelValues, textValues));
            }
        }
    } catch (SolrServerException e) {
        String msg = String.format("Error while fetching %s examples for topics ['%s'] on Solr Core '%s'.", positive ? "positive" : "negative", StringUtils.join(topics, "', '"), solrCoreId);
        throw new TrainingSetException(msg, e);
    }
    return new Batch<Example>(items, nextExampleId != null, nextExampleId);
}
Also used : SolrServerException(org.apache.solr.client.solrj.SolrServerException) ArrayList(java.util.ArrayList) ManagedSolrServer(org.apache.stanbol.commons.solr.managed.ManagedSolrServer) SolrServer(org.apache.solr.client.solrj.SolrServer) SolrQuery(org.apache.solr.client.solrj.SolrQuery) SolrDocument(org.apache.solr.common.SolrDocument) Batch(org.apache.stanbol.enhancer.topic.api.Batch) Example(org.apache.stanbol.enhancer.topic.api.training.Example) QueryResponse(org.apache.solr.client.solrj.response.QueryResponse) TrainingSetException(org.apache.stanbol.enhancer.topic.api.training.TrainingSetException)

Example 3 with Example

use of org.apache.stanbol.enhancer.topic.api.training.Example in project stanbol by apache.

the class TrainingSetTest method testBatchingPositiveExamples.

/**
 * Registers 28 examples for TOPIC_1 and checks that paging through the positive examples with a batch
 * size of 10 yields batches of 10, 10 and 8 items that together hold every registered id and text.
 */
@Test
public void testBatchingPositiveExamples() throws ConfigurationException, TrainingSetException {
    log.info(" --- testBatchingPositiveExamples --- ");
    Set<String> expectedCollectedIds = new HashSet<String>();
    Set<String> expectedCollectedText = new HashSet<String>();
    Set<String> collectedIds = new HashSet<String>();
    Set<String> collectedText = new HashSet<String>();

    int registered = 0;
    while (registered < 28) {
        String exampleId = "example-" + registered;
        String exampleText = "Text of example" + registered + ".";
        trainingSet.registerExample(exampleId, exampleText, Arrays.asList(TOPIC_1));
        expectedCollectedIds.add(exampleId);
        expectedCollectedText.add(exampleText);
        registered++;
    }
    trainingSet.setBatchSize(10);

    // size expected for each successive batch; only the last one may report that nothing is left
    int[] expectedBatchSizes = {10, 10, 8};
    Object offset = null;
    for (int batchIndex = 0; batchIndex < expectedBatchSizes.length; batchIndex++) {
        Batch<Example> examples = trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1, TOPIC_2), offset);
        assertEquals(expectedBatchSizes[batchIndex], examples.items.size());
        for (Example example : examples.items) {
            collectedIds.add(example.id);
            collectedText.add(example.getContentString());
        }
        if (batchIndex < expectedBatchSizes.length - 1) {
            assertTrue(examples.hasMore);
        } else {
            assertFalse(examples.hasMore);
        }
        offset = examples.nextOffset;
    }

    assertEquals(expectedCollectedIds, collectedIds);
    assertEquals(expectedCollectedText, collectedText);
}
Also used : Example(org.apache.stanbol.enhancer.topic.api.training.Example) HashSet(java.util.HashSet) Test(org.junit.Test)

Example 4 with Example

use of org.apache.stanbol.enhancer.topic.api.training.Example in project stanbol by apache.

the class TrainingSetTest method testBatchingNegativeExamplesAndAutoId.

/**
 * Registers 17 examples for TOPIC_1 without giving an id (the id returned by the training set is
 * recorded instead) and checks that paging through the negative examples of TOPIC_2 with a batch size
 * of 10 yields batches of 10 and 7 items that together hold every registered id and text.
 */
@Test
public void testBatchingNegativeExamplesAndAutoId() throws ConfigurationException, TrainingSetException {
    log.info(" --- testBatchingNegativeExamplesAndAutoId --- ");
    Set<String> expectedCollectedIds = new HashSet<String>();
    Set<String> expectedCollectedText = new HashSet<String>();
    Set<String> collectedIds = new HashSet<String>();
    Set<String> collectedText = new HashSet<String>();

    int registered = 0;
    while (registered < 17) {
        String exampleText = "Text of example" + registered + ".";
        // passing null as id: the training set hands back the id it stored the example under
        String assignedId = trainingSet.registerExample(null, exampleText, Arrays.asList(TOPIC_1));
        expectedCollectedIds.add(assignedId);
        expectedCollectedText.add(exampleText);
        registered++;
    }
    trainingSet.setBatchSize(10);

    // size expected for each successive batch; only the last one may report that nothing is left
    int[] expectedBatchSizes = {10, 7};
    Object offset = null;
    for (int batchIndex = 0; batchIndex < expectedBatchSizes.length; batchIndex++) {
        Batch<Example> examples = trainingSet.getNegativeExamples(Arrays.asList(TOPIC_2), offset);
        assertEquals(expectedBatchSizes[batchIndex], examples.items.size());
        for (Example example : examples.items) {
            collectedIds.add(example.id);
            collectedText.add(example.getContentString());
        }
        if (batchIndex < expectedBatchSizes.length - 1) {
            assertTrue(examples.hasMore);
        } else {
            assertFalse(examples.hasMore);
        }
        offset = examples.nextOffset;
    }

    assertEquals(expectedCollectedIds, collectedIds);
    assertEquals(expectedCollectedText, collectedText);
}
Also used : Example(org.apache.stanbol.enhancer.topic.api.training.Example) HashSet(java.util.HashSet) Test(org.junit.Test)

Example 5 with Example

use of org.apache.stanbol.enhancer.topic.api.training.Example in project stanbol by apache.

the class TopicClassificationEngine method updateTopic.

/**
 * Re-index one topic: gather the text of its positive training examples into the model entry and
 * rewrite the matching metadata entry. Both documents are sent to Solr in a single update request;
 * the commit is left to the caller.
 * <p>
 * While a cross validation session is configured ({@code cvFoldCount != 0}) the examples whose running
 * position falls in the test fold are left out of the model text.
 *
 * @param conceptUri
 *            the topic model to update
 * @param metadataId
 *            the id of the metadata entry of the topic
 * @param modelId
 *            the id of the model entry of the topic
 * @param impactedTopics
 *            the list of impacted topics (e.g. the topic node and direct children)
 * @param primaryTopicUri
 *            the value stored in the primary topic field of the metadata entry, when that field is
 *            configured
 * @param broaderConcepts
 *            the collection of broader to re-add in the broader field
 * @throws TrainingSetException
 *             declared for the lookup of the positive examples in the training set
 * @throws ClassifierException
 *             if the update request cannot be executed on the Solr server
 */
protected void updateTopic(String conceptUri, String metadataId, String modelId, List<String> impactedTopics, String primaryTopicUri, Collection<Object> broaderConcepts) throws TrainingSetException, ClassifierException {
    long start = System.currentTimeMillis();
    Batch<Example> examples = Batch.emtpyBatch(Example.class);
    // the buffer is local to this call, no need for the synchronized StringBuffer
    StringBuilder sb = new StringBuilder();
    int offset = 0;
    do {
        examples = getTrainingSet().getPositiveExamples(impactedTopics, examples.nextOffset);
        for (Example example : examples.items) {
            boolean inTestFold = (cvFoldCount != 0) && (offset % cvFoldCount == cvFoldIndex);
            offset++;
            if (inTestFold) {
                // we are performing a cross validation session and this example belongs to the test
                // fold hence should be skipped
                continue;
            }
            sb.append(StringUtils.join(example.contents, "\n\n"));
            sb.append("\n\n");
        }
    } while (sb.length() < MAX_CHARS_PER_TOPIC && examples.hasMore);
    // reindex the topic with the new text data collected from the examples
    SolrInputDocument modelEntry = new SolrInputDocument();
    modelEntry.addField(entryIdField, modelId);
    modelEntry.addField(conceptUriField, conceptUri);
    modelEntry.addField(entryTypeField, MODEL_ENTRY);
    if (sb.length() > 0) {
        // store the text itself rather than the buffer object, so that the indexed value does not
        // depend on how the Solr client turns an arbitrary object into a field value
        modelEntry.addField(similarityField, sb.toString());
    }
    // update the metadata of the topic model
    SolrInputDocument metadataEntry = new SolrInputDocument();
    metadataEntry.addField(entryIdField, metadataId);
    metadataEntry.addField(modelEntryIdField, modelId);
    metadataEntry.addField(entryTypeField, METADATA_ENTRY);
    metadataEntry.addField(conceptUriField, conceptUri);
    if (primaryTopicUriField != null) {
        metadataEntry.addField(primaryTopicUriField, primaryTopicUri);
    }
    if (broaderConcepts != null && broaderField != null) {
        metadataEntry.addField(broaderField, broaderConcepts);
    }
    if (modelUpdateDateField != null) {
        metadataEntry.addField(modelUpdateDateField, UTCTimeStamper.nowUtcDate());
    }
    SolrServer solrServer = getActiveSolrServer();
    try {
        UpdateRequest request = new UpdateRequest();
        request.add(metadataEntry);
        request.add(modelEntry);
        solrServer.request(request);
        // the commit is done by the caller in batch
    } catch (Exception e) {
        String msg = String.format("Error updating topic with id '%s' on Solr Core '%s'", conceptUri, solrCoreId);
        throw new ClassifierException(msg, e);
    }
    long stop = System.currentTimeMillis();
    log.debug("Sucessfully updated topic {} in {}s", conceptUri, (double) (stop - start) / 1000.);
}
Also used : SolrInputDocument(org.apache.solr.common.SolrInputDocument) UpdateRequest(org.apache.solr.client.solrj.request.UpdateRequest) Example(org.apache.stanbol.enhancer.topic.api.training.Example) EmbeddedSolrServer(org.apache.solr.client.solrj.embedded.EmbeddedSolrServer) SolrServer(org.apache.solr.client.solrj.SolrServer) ManagedSolrServer(org.apache.stanbol.commons.solr.managed.ManagedSolrServer) EngineException(org.apache.stanbol.enhancer.servicesapi.EngineException) SolrServerException(org.apache.solr.client.solrj.SolrServerException) ConfigurationException(org.osgi.service.cm.ConfigurationException) InvalidSyntaxException(org.osgi.framework.InvalidSyntaxException) TrainingSetException(org.apache.stanbol.enhancer.topic.api.training.TrainingSetException) ClassifierException(org.apache.stanbol.enhancer.topic.api.ClassifierException) InvalidContentException(org.apache.stanbol.enhancer.servicesapi.InvalidContentException) EntityhubException(org.apache.stanbol.entityhub.servicesapi.EntityhubException) ChainException(org.apache.stanbol.enhancer.servicesapi.ChainException) IOException(java.io.IOException) ClassifierException(org.apache.stanbol.enhancer.topic.api.ClassifierException)

Aggregations

Example (org.apache.stanbol.enhancer.topic.api.training.Example)5 SolrServerException (org.apache.solr.client.solrj.SolrServerException)3 TrainingSetException (org.apache.stanbol.enhancer.topic.api.training.TrainingSetException)3 IOException (java.io.IOException)2 ArrayList (java.util.ArrayList)2 HashSet (java.util.HashSet)2 SolrServer (org.apache.solr.client.solrj.SolrServer)2 SolrDocument (org.apache.solr.common.SolrDocument)2 ManagedSolrServer (org.apache.stanbol.commons.solr.managed.ManagedSolrServer)2 ChainException (org.apache.stanbol.enhancer.servicesapi.ChainException)2 EngineException (org.apache.stanbol.enhancer.servicesapi.EngineException)2 InvalidContentException (org.apache.stanbol.enhancer.servicesapi.InvalidContentException)2 Batch (org.apache.stanbol.enhancer.topic.api.Batch)2 ClassifierException (org.apache.stanbol.enhancer.topic.api.ClassifierException)2 EntityhubException (org.apache.stanbol.entityhub.servicesapi.EntityhubException)2 Test (org.junit.Test)2 InvalidSyntaxException (org.osgi.framework.InvalidSyntaxException)2 ConfigurationException (org.osgi.service.cm.ConfigurationException)2 File (java.io.File)1 Collection (java.util.Collection)1