Search in sources :

Example 1 with TrainingCapability

Use of de.tudarmstadt.ukp.inception.recommendation.api.recommender.TrainingCapability in project inception by inception-project.

The method execute of the class TrainingTask.

/**
 * Trains every enabled recommender on every enabled annotation layer of the task's project and,
 * if at least one recommender trained successfully or does not support training at all, enqueues
 * a {@code PredictionTask} that inherits the log messages collected here.
 *
 * <p>
 * A failure of one recommender is logged and published as a {@code RecommenderTaskEvent}; the
 * remaining recommenders are still processed.
 *
 * <p>
 * NOTE(review): {@code getUser().orElseThrow()} presumably raises an exception when the task has
 * no user - confirm against the declaration of {@code getUser()}.
 */
@Override
public void execute() {
    try (CasStorageSession session = CasStorageSession.open()) {
        Project project = getProject();
        User user = getUser().orElseThrow();
        log.debug("[{}][{}]: Starting training for project {} triggered by [{}]...", getId(),
                user.getUsername(), project, getTrigger());
        logMessages.add(info(this, "Starting training triggered by [%s]...", getTrigger()));

        // Read the CASes only when they are accessed the first time. This allows us to skip
        // reading the CASes in case that no layer / recommender is available or if no
        // recommender requires evaluation.
        LazyInitializer<List<TrainingDocument>> casses = new LazyInitializer<List<TrainingDocument>>() {

            @Override
            protected List<TrainingDocument> initialize() {
                return readCasses(project, user);
            }
        };

        // Either flag being set is sufficient to schedule the prediction run at the end.
        boolean seenSuccessfulTraining = false;
        boolean seenNonTrainingRecommender = false;

        for (AnnotationLayer layer : annoService.listAnnotationLayer(project)) {
            if (!layer.isEnabled()) {
                continue;
            }

            List<EvaluatedRecommender> recommenders = recommendationService
                    .getActiveRecommenders(user, layer);
            if (recommenders.isEmpty()) {
                log.trace("[{}][{}][{}]: No active recommenders, skipping training.", getId(),
                        user.getUsername(), layer.getUiName());
                logMessages.add(info(this,
                        "No active recommenders for layer [%s], skipping training.",
                        layer.getUiName()));
                continue;
            }

            for (EvaluatedRecommender r : recommenders) {
                // Make sure we have the latest recommender config from the DB - the one from
                // the active recommenders list may be outdated
                Recommender recommender;
                try {
                    recommender = recommendationService.getRecommender(r.getRecommender().getId());
                } catch (NoResultException e) {
                    log.debug("[{}][{}][{}]: Recommender no longer available... skipping", getId(),
                            user.getUsername(), r.getRecommender().getName());
                    continue;
                }

                if (!recommender.isEnabled()) {
                    // Argument order matches the [id][user][recommender] prefix used by all
                    // other log statements in this method.
                    log.debug("[{}][{}][{}]: Disabled - skipping", getId(), user.getUsername(),
                            r.getRecommender().getName());
                    continue;
                }

                long startTime = System.currentTimeMillis();
                try {
                    Optional<RecommendationEngineFactory<?>> maybeFactory = recommendationService
                            .getRecommenderFactory(recommender);
                    if (maybeFactory.isEmpty()) {
                        log.warn("[{}][{}][{}]: No factory found - skipping recommender", getId(),
                                user.getUsername(), r.getRecommender().getName());
                        continue;
                    }

                    RecommendationEngineFactory<?> factory = maybeFactory.get();
                    if (!factory.accepts(recommender.getLayer(), recommender.getFeature())) {
                        log.debug("[{}][{}][{}]: Recommender configured with invalid layer or " + "feature - skipping recommender", getId(), user.getUsername(), r.getRecommender().getName());
                        logMessages.add(error(this,
                                "Recommender [%s] configured with invalid layer or feature - skipping recommender.",
                                r.getRecommender().getName()));
                        appEventPublisher.publishEvent(new RecommenderTaskEvent(this,
                                user.getUsername(),
                                "Recommender configured with invalid layer or feature - skipping training recommender.",
                                recommender));
                        continue;
                    }

                    RecommendationEngine recommendationEngine = factory.build(recommender);
                    RecommenderContext ctx = recommendationEngine
                            .newContext(recommendationService.getContext(user, recommender)
                                    .orElse(RecommenderContext.EMPTY_CONTEXT));
                    ctx.setUser(user);

                    TrainingCapability capability = recommendationEngine.getTrainingCapability();

                    // If the engine does not support training, store its context right away and
                    // remember that a prediction run is still worthwhile.
                    if (capability == TRAINING_NOT_SUPPORTED) {
                        seenNonTrainingRecommender = true;
                        log.debug("[{}][{}][{}]: Engine does not support training", getId(),
                                user.getUsername(), recommender.getName());
                        ctx.close();
                        recommendationService.putContext(user, recommender, ctx);
                        continue;
                    }

                    // Keep only documents whose state is not ignored by this recommender and
                    // which contain the recommender's target type and feature.
                    List<CAS> cassesForTraining = casses.get().stream()
                            .filter(e -> !recommender.getStatesIgnoredForTraining().contains(e.state))
                            .filter(e -> containsTargetTypeAndFeature(recommender, e.cas))
                            .map(e -> e.cas)
                            .collect(toList());

                    // If no data for training is available but the engine requires training, do
                    // not store a context for it.
                    if (cassesForTraining.isEmpty() && capability == TRAINING_REQUIRED) {
                        log.debug("[{}][{}][{}]: There are no annotations available to train on",
                                getId(), user.getUsername(), recommender.getName());
                        logMessages.add(warn(this,
                                "There are no [%s] annotations available to train on.",
                                layer.getUiName()));
                        // This can happen if there were already predictions based on existing
                        // annotations, but all annotations have been removed/deleted. To ensure
                        // that the prediction run removes the stale predictions, we need to
                        // call it a success here.
                        seenSuccessfulTraining = true;
                        continue;
                    }

                    log.debug("[{}][{}][{}]: Training model on [{}] out of [{}] documents ...",
                            getId(), user.getUsername(), recommender.getName(),
                            cassesForTraining.size(), casses.get().size());
                    logMessages.add(info(this,
                            "Training model for [%s] on [%d] out of [%d] documents ...",
                            layer.getUiName(), cassesForTraining.size(), casses.get().size()));

                    recommendationEngine.train(ctx, cassesForTraining);
                    logMessages.addAll(ctx.getMessages());

                    long duration = System.currentTimeMillis() - startTime;

                    if (!recommendationEngine.isReadyForPrediction(ctx)) {
                        int docNum = casses.get().size();
                        int trainDocNum = cassesForTraining.size();
                        log.debug("[{}][{}][{}]: Training on [{}] out of [{}] documents not successful ({} ms)",
                                getId(), user.getUsername(), recommender.getName(), trainDocNum,
                                docNum, duration);
                        logMessages.add(info(this, "Training not successful (%d ms).", duration));
                        appEventPublisher.publishEvent(new RecommenderTaskEvent(this,
                                user.getUsername(),
                                format("Training on %d out of %d documents not successful (%d ms)",
                                        trainDocNum, docNum, duration),
                                recommender));
                        continue;
                    }

                    log.debug("[{}][{}][{}]: Training successful on [{}] out of [{}] documents ({} ms)",
                            getId(), user.getUsername(), recommender.getName(),
                            cassesForTraining.size(), casses.get().size(), duration);
                    logMessages.add(info(this,
                            "Training successful on [%d] out of [%d] documents (%d ms)",
                            cassesForTraining.size(), casses.get().size(), duration));
                    seenSuccessfulTraining = true;

                    ctx.close();
                    recommendationService.putContext(user, recommender, ctx);
                }
                // Catching Throwable is intentional here: the loop must continue with the
                // remaining recommenders even if a particular recommender fails.
                catch (Throwable e) {
                    long duration = System.currentTimeMillis() - startTime;
                    // Reuse the duration computed above so that the log line and the
                    // user-visible messages report the same value.
                    log.error("[{}][{}][{}]: Training failed ({} ms)", getId(), user.getUsername(),
                            recommender.getName(), duration, e);
                    logMessages.add(error(this, "Training failed (%d ms): %s", duration,
                            getRootCauseMessage(e)));
                    appEventPublisher.publishEvent(new RecommenderTaskEvent(this,
                            user.getUsername(),
                            String.format("Training failed (%d ms) with %s", duration,
                                    e.getMessage()),
                            recommender));
                }
            }
        }

        // Nothing was trained and no recommender can predict without training, so a
        // prediction run is not scheduled.
        if (!seenSuccessfulTraining && !seenNonTrainingRecommender) {
            log.debug("[{}][{}]: No recommenders trained successfully and no non-training " + "recommenders, skipping prediction.", getId(), user.getUsername());
            return;
        }

        PredictionTask predictionTask = new PredictionTask(user,
                String.format("TrainingTask %s complete", getId()), currentDocument);
        predictionTask.inheritLog(logMessages);
        schedulingService.enqueue(predictionTask);
    }
}
Also used : LogMessage.error(de.tudarmstadt.ukp.clarin.webanno.support.logging.LogMessage.error) TrainingCapability(de.tudarmstadt.ukp.inception.recommendation.api.recommender.TrainingCapability) LogMessage(de.tudarmstadt.ukp.clarin.webanno.support.logging.LogMessage) RecommendationService(de.tudarmstadt.ukp.inception.recommendation.api.RecommendationService) LoggerFactory(org.slf4j.LoggerFactory) NoResultException(javax.persistence.NoResultException) CAS(org.apache.uima.cas.CAS) Autowired(org.springframework.beans.factory.annotation.Autowired) RecommendationEngine(de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine) Task(de.tudarmstadt.ukp.inception.scheduling.Task) ArrayList(java.util.ArrayList) Type(org.apache.uima.cas.Type) AnnotationSchemaService(de.tudarmstadt.ukp.clarin.webanno.api.AnnotationSchemaService) User(de.tudarmstadt.ukp.clarin.webanno.security.model.User) EvaluatedRecommender(de.tudarmstadt.ukp.inception.recommendation.api.model.EvaluatedRecommender) Map(java.util.Map) SchedulingService(de.tudarmstadt.ukp.inception.scheduling.SchedulingService) ApplicationEventPublisher(org.springframework.context.ApplicationEventPublisher) RecommenderContext(de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext) Project(de.tudarmstadt.ukp.clarin.webanno.model.Project) SHARED_READ_ONLY_ACCESS(de.tudarmstadt.ukp.clarin.webanno.api.casstorage.CasAccessMode.SHARED_READ_ONLY_ACCESS) LogMessage.warn(de.tudarmstadt.ukp.clarin.webanno.support.logging.LogMessage.warn) CasStorageSession(de.tudarmstadt.ukp.clarin.webanno.api.dao.casstorage.CasStorageSession) DocumentService(de.tudarmstadt.ukp.clarin.webanno.api.DocumentService) Logger(org.slf4j.Logger) TRAINING_REQUIRED(de.tudarmstadt.ukp.inception.recommendation.api.recommender.TrainingCapability.TRAINING_REQUIRED) ExceptionUtils.getRootCauseMessage(org.apache.commons.lang3.exception.ExceptionUtils.getRootCauseMessage) IOException(java.io.IOException) 
AnnotationDocument(de.tudarmstadt.ukp.clarin.webanno.model.AnnotationDocument) Recommender(de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender) RecommendationEngineFactory(de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngineFactory) AnnotationDocumentState(de.tudarmstadt.ukp.clarin.webanno.model.AnnotationDocumentState) RecommenderTaskEvent(de.tudarmstadt.ukp.inception.recommendation.event.RecommenderTaskEvent) String.format(java.lang.String.format) LazyInitializer(org.apache.commons.lang3.concurrent.LazyInitializer) AUTO_CAS_UPGRADE(de.tudarmstadt.ukp.clarin.webanno.api.CasUpgradeMode.AUTO_CAS_UPGRADE) LogMessage.info(de.tudarmstadt.ukp.clarin.webanno.support.logging.LogMessage.info) CasUtil(org.apache.uima.fit.util.CasUtil) Collectors.toList(java.util.stream.Collectors.toList) List(java.util.List) TRAINING_NOT_SUPPORTED(de.tudarmstadt.ukp.inception.recommendation.api.recommender.TrainingCapability.TRAINING_NOT_SUPPORTED) AnnotationLayer(de.tudarmstadt.ukp.clarin.webanno.model.AnnotationLayer) SourceDocument(de.tudarmstadt.ukp.clarin.webanno.model.SourceDocument) Optional(java.util.Optional) User(de.tudarmstadt.ukp.clarin.webanno.security.model.User) EvaluatedRecommender(de.tudarmstadt.ukp.inception.recommendation.api.model.EvaluatedRecommender) NoResultException(javax.persistence.NoResultException) AnnotationLayer(de.tudarmstadt.ukp.clarin.webanno.model.AnnotationLayer) EvaluatedRecommender(de.tudarmstadt.ukp.inception.recommendation.api.model.EvaluatedRecommender) Recommender(de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender) ArrayList(java.util.ArrayList) Collectors.toList(java.util.stream.Collectors.toList) List(java.util.List) RecommenderTaskEvent(de.tudarmstadt.ukp.inception.recommendation.event.RecommenderTaskEvent) CasStorageSession(de.tudarmstadt.ukp.clarin.webanno.api.dao.casstorage.CasStorageSession) 
RecommendationEngine(de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine) LazyInitializer(org.apache.commons.lang3.concurrent.LazyInitializer) RecommenderContext(de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext) Project(de.tudarmstadt.ukp.clarin.webanno.model.Project) RecommendationEngineFactory(de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngineFactory) CAS(org.apache.uima.cas.CAS) TrainingCapability(de.tudarmstadt.ukp.inception.recommendation.api.recommender.TrainingCapability)

Aggregations

AnnotationSchemaService (de.tudarmstadt.ukp.clarin.webanno.api.AnnotationSchemaService)1 AUTO_CAS_UPGRADE (de.tudarmstadt.ukp.clarin.webanno.api.CasUpgradeMode.AUTO_CAS_UPGRADE)1 DocumentService (de.tudarmstadt.ukp.clarin.webanno.api.DocumentService)1 SHARED_READ_ONLY_ACCESS (de.tudarmstadt.ukp.clarin.webanno.api.casstorage.CasAccessMode.SHARED_READ_ONLY_ACCESS)1 CasStorageSession (de.tudarmstadt.ukp.clarin.webanno.api.dao.casstorage.CasStorageSession)1 AnnotationDocument (de.tudarmstadt.ukp.clarin.webanno.model.AnnotationDocument)1 AnnotationDocumentState (de.tudarmstadt.ukp.clarin.webanno.model.AnnotationDocumentState)1 AnnotationLayer (de.tudarmstadt.ukp.clarin.webanno.model.AnnotationLayer)1 Project (de.tudarmstadt.ukp.clarin.webanno.model.Project)1 SourceDocument (de.tudarmstadt.ukp.clarin.webanno.model.SourceDocument)1 User (de.tudarmstadt.ukp.clarin.webanno.security.model.User)1 LogMessage (de.tudarmstadt.ukp.clarin.webanno.support.logging.LogMessage)1 LogMessage.error (de.tudarmstadt.ukp.clarin.webanno.support.logging.LogMessage.error)1 LogMessage.info (de.tudarmstadt.ukp.clarin.webanno.support.logging.LogMessage.info)1 LogMessage.warn (de.tudarmstadt.ukp.clarin.webanno.support.logging.LogMessage.warn)1 RecommendationService (de.tudarmstadt.ukp.inception.recommendation.api.RecommendationService)1 EvaluatedRecommender (de.tudarmstadt.ukp.inception.recommendation.api.model.EvaluatedRecommender)1 Recommender (de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender)1 RecommendationEngine (de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine)1 RecommendationEngineFactory (de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngineFactory)1