Search in sources:

Example 6 with TaskContext

Use of org.dkpro.lab.engine.TaskContext in project dkpro-lab by dkpro.

The class PosExampleMaxEnt, method run.

/**
 * Runs the POS-tagging example end to end: the NEGRA export corpus is preprocessed and stored
 * as XMI, features are extracted, a MaxEnt-based model is trained for each configured
 * combination of the "iterations" and "cutoff" parameters, and the trained model is applied
 * to plain-text documents.
 *
 * @throws Exception
 *             propagated from the lab run if any task of the batch fails
 */
@Test
public void run() throws Exception {
    // Route logging through log4j
    System.setProperty("org.apache.uima.logger.class", "org.apache.uima.util.impl.Log4jLogger_impl");
    clean();
    // Reads the NEGRA export corpus, stems it and writes the documents as gzipped XMI
    Task preprocessingTask = new UimaTaskBase() {

        @Discriminator
        String corpusPath;

        {
            setType("Preprocessing");
        }

        @Override
        public CollectionReaderDescription getCollectionReaderDescription(TaskContext aContext) throws ResourceInitializationException, IOException {
            return createReaderDescription(NegraExportReader.class, NegraExportReader.PARAM_SOURCE_LOCATION, corpusPath, NegraExportReader.PARAM_LANGUAGE, "de");
        }

        @Override
        public AnalysisEngineDescription getAnalysisEngineDescription(TaskContext aContext) throws ResourceInitializationException, IOException {
            File xmiDir = aContext.getFolder("XMI", AccessMode.READWRITE);
            // Aggregate description: stemming followed by compressed XMI serialization
            return createEngineDescription(createEngineDescription(SnowballStemmer.class), createEngineDescription(XmiWriter.class, XmiWriter.PARAM_TARGET_LOCATION, xmiDir.getAbsolutePath(), XmiWriter.PARAM_COMPRESSION, CompressionMethod.GZIP));
        }
    };
    // Reads the XMI output of the preprocessing task and writes classifier training data
    // (Viterbi data writer delegating to the MaxEnt data writer) into the MODEL folder
    Task featureExtractionTask = new UimaTaskBase() {

        {
            setType("FeatureExtraction");
        }

        @Override
        public CollectionReaderDescription getCollectionReaderDescription(TaskContext aContext) throws ResourceInitializationException, IOException {
            File xmiDir = aContext.getFolder("XMI", AccessMode.READONLY);
            return createReaderDescription(XmiReader.class, XmiReader.PARAM_SOURCE_LOCATION, xmiDir.getAbsolutePath(), XmiReader.PARAM_PATTERNS, new String[] { "[+]**/*.xmi.gz" });
        }

        @Override
        public AnalysisEngineDescription getAnalysisEngineDescription(TaskContext aContext) throws ResourceInitializationException, IOException {
            File modelDir = aContext.getFolder("MODEL", AccessMode.READWRITE);
            return createEngineDescription(ExamplePosAnnotator.class, ExamplePosAnnotator.PARAM_DATA_WRITER_FACTORY_CLASS_NAME, ViterbiDataWriterFactory.class.getName(), ViterbiDataWriterFactory.PARAM_OUTPUT_DIRECTORY, modelDir.getAbsolutePath(), ViterbiDataWriterFactory.PARAM_DELEGATED_DATA_WRITER_FACTORY_CLASS, DefaultMaxentDataWriterFactory.class.getName());
        }
    };
    // Trains the classifier on the data in MODEL and packages it as a jar in the same folder
    Task trainingTask = new ExecutableTaskBase() {

        @Discriminator
        private int iterations;

        @Discriminator
        private int cutoff;

        {
            setType("TrainingTask");
        }

        @Override
        public void execute(TaskContext aContext) throws Exception {
            File dir = aContext.getFolder("MODEL", AccessMode.READWRITE);
            JarClassifierBuilder<?> classifierBuilder = JarClassifierBuilder.fromTrainingDirectory(dir);
            // The discriminator values are handed to the trainer as its argument list
            classifierBuilder.trainClassifier(dir, new String[] { String.valueOf(iterations), String.valueOf(cutoff) });
            classifierBuilder.packageClassifier(dir);
        }
    };
    // Tags the plain-text test documents with the packaged model and writes TSV output
    Task analysisTask = new UimaTaskBase() {

        {
            setType("AnalysisTask");
        }

        @Override
        public CollectionReaderDescription getCollectionReaderDescription(TaskContext aContext) throws ResourceInitializationException, IOException {
            return createReaderDescription(TextReader.class, TextReader.PARAM_SOURCE_LOCATION, "src/test/resources/text/**/*.txt", TextReader.PARAM_LANGUAGE, "de");
        }

        @Override
        public AnalysisEngineDescription getAnalysisEngineDescription(TaskContext aContext) throws ResourceInitializationException, IOException {
            File model = new File(aContext.getFolder("MODEL", AccessMode.READONLY), "model.jar");
            File tsv = new File(aContext.getFolder("TSV", AccessMode.READWRITE), "output.tsv");
            return createEngineDescription(createEngineDescription(BreakIteratorSegmenter.class), createEngineDescription(SnowballStemmer.class), createEngineDescription(ExamplePosAnnotator.class, GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH, model.getAbsolutePath()), createEngineDescription(ImsCwbWriter.class, ImsCwbWriter.PARAM_TARGET_LOCATION, tsv));
        }
    };
    ParameterSpace pSpace = new ParameterSpace(Dimension.create("corpusPath", CORPUS_PATH), Dimension.create("iterations", 20, 50, 100), Dimension.create("cutoff", 5));
    // Wire the data flow between the tasks
    featureExtractionTask.addImport(preprocessingTask, "XMI");
    trainingTask.addImport(featureExtractionTask, "MODEL");
    analysisTask.addImport(trainingTask, "MODEL");
    DefaultBatchTask batch = new DefaultBatchTask();
    batch.setParameterSpace(pSpace);
    batch.setExecutionPolicy(ExecutionPolicy.USE_EXISTING);
    batch.addTask(preprocessingTask);
    batch.addTask(featureExtractionTask);
    batch.addTask(trainingTask);
    batch.addTask(analysisTask);
    Lab.getInstance().run(batch);
}
Also used: Task(org.dkpro.lab.task.Task) DefaultBatchTask(org.dkpro.lab.task.impl.DefaultBatchTask) UimaTaskBase(org.dkpro.lab.uima.task.impl.UimaTaskBase) TaskContext(org.dkpro.lab.engine.TaskContext) SnowballStemmer(de.tudarmstadt.ukp.dkpro.core.snowball.SnowballStemmer) XmiWriter(de.tudarmstadt.ukp.dkpro.core.io.xmi.XmiWriter) ExecutableTaskBase(org.dkpro.lab.task.impl.ExecutableTaskBase) ExamplePosAnnotator(org.dkpro.lab.ml.example.ExamplePosAnnotator) ParameterSpace(org.dkpro.lab.task.ParameterSpace) BreakIteratorSegmenter(de.tudarmstadt.ukp.dkpro.core.tokit.BreakIteratorSegmenter) ImsCwbWriter(de.tudarmstadt.ukp.dkpro.core.io.imscwb.ImsCwbWriter) File(java.io.File) Test(org.junit.Test)

Example 7 with TaskContext

Use of org.dkpro.lab.engine.TaskContext in project dkpro-lab by dkpro.

The class CpeExecutionEngine, method run.

/**
 * Executes a {@link UimaTask} inside a UIMA Collection Processing Engine (CPE) that uses as
 * many processing threads as there are available processors.
 *
 * @param aConfiguration
 *            the task to run; must be a {@link UimaTask}
 * @return the ID of the task context the task was executed in
 * @throws ExecutionException
 *             if the task is not a UIMA task, if setting up or running the CPE fails, or if
 *             the thread waiting for the CPE is interrupted
 * @throws LifeCycleException
 *             if one of the life-cycle callbacks fails
 */
@Override
public String run(Task aConfiguration) throws ExecutionException, LifeCycleException {
    if (!(aConfiguration instanceof UimaTask)) {
        throw new ExecutionException("This engine can only execute [" + UimaTask.class.getName() + "]");
    }
    UimaTask configuration = (UimaTask) aConfiguration;
    // Create persistence service for injection into analysis components
    TaskContext ctx = contextFactory.createContext(aConfiguration);
    try {
        ResourceManager resMgr = newDefaultResourceManager();
        // Make sure the descriptor is fully resolved. It will be modified and
        // thus should not be modified again afterwards by UIMA.
        AnalysisEngineDescription analysisDesc = configuration.getAnalysisEngineDescription(ctx);
        analysisDesc.resolveImports(resMgr);
        // Scan components that accept the service and bind it to them
        bindResource(analysisDesc, TaskContext.class, TaskContextProvider.class, TaskContextProvider.PARAM_FACTORY_NAME, contextFactory.getId(), TaskContextProvider.PARAM_CONTEXT_ID, ctx.getId());
        CpeBuilder mgr = new CpeBuilder();
        int threadCount = Runtime.getRuntime().availableProcessors();
        ctx.message("CPE will be using " + threadCount + " parallel threads to optimally utilize your cpu cores");
        mgr.setMaxProcessingUnitThreadCount(threadCount);
        mgr.setReader(configuration.getCollectionReaderDescription(ctx));
        mgr.setAnalysisEngine(analysisDesc);
        StatusCallbackListenerImpl status = new StatusCallbackListenerImpl(ctx);
        CollectionProcessingEngine engine = mgr.createCpe(status);
        // Now the setup is complete
        ctx.getLifeCycleManager().initialize(ctx, aConfiguration);
        // Start recording
        ctx.getLifeCycleManager().begin(ctx, aConfiguration);
        // Run the experiment; process() returns immediately, so block until the status
        // listener reports that processing has finished
        engine.process();
        try {
            synchronized (status) {
                while (status.isProcessing) {
                    status.wait();
                }
            }
        } catch (InterruptedException e) {
            ctx.message("CPE interrupted.");
            // Keep the interrupt visible to callers, stop the still-running CPE and report
            // the run as failed instead of recording it as complete
            Thread.currentThread().interrupt();
            engine.stop();
            throw e;
        }
        if (status.exceptions.size() > 0) {
            throw status.exceptions.get(0);
        }
        // End recording
        ctx.getLifeCycleManager().complete(ctx, aConfiguration);
        return ctx.getId();
    } catch (LifeCycleException e) {
        ctx.getLifeCycleManager().fail(ctx, aConfiguration, e);
        throw e;
    } catch (Throwable e) {
        ctx.getLifeCycleManager().fail(ctx, aConfiguration, e);
        throw new ExecutionException(e);
    } finally {
        if (ctx != null) {
            ctx.getLifeCycleManager().destroy(ctx, aConfiguration);
        }
    }
}
Also used : TaskContext(org.dkpro.lab.engine.TaskContext) AnalysisEngineDescription(org.apache.uima.analysis_engine.AnalysisEngineDescription) ResourceManager(org.apache.uima.resource.ResourceManager) UIMAFramework.newDefaultResourceManager(org.apache.uima.UIMAFramework.newDefaultResourceManager) CollectionProcessingEngine(org.apache.uima.collection.CollectionProcessingEngine) LifeCycleException(org.dkpro.lab.engine.LifeCycleException) UimaTask(org.dkpro.lab.uima.task.UimaTask) ExecutionException(org.dkpro.lab.engine.ExecutionException) CpeBuilder(org.apache.uima.fit.cpe.CpeBuilder)

Example 8 with TaskContext

Use of org.dkpro.lab.engine.TaskContext in project dkpro-tc by dkpro.

The class ExperimentCrossValidation, method init.

/**
 * Initializes the experiment. This is called automatically before execution. It's not done
 * directly in the constructor, because we want to be able to use setters instead of the
 * three-argument constructor.
 *
 * @throws IllegalStateException
 *             if no experiment name is set or fewer than two folds are configured
 */
protected void init() throws IllegalStateException {
    if (experimentName == null) {
        throw new IllegalStateException("You must set an experiment name");
    }
    if (numFolds < 2) {
        throw new IllegalStateException("Number of folds is not configured correctly. Number of folds needs to be at " + "least 2 (but was " + numFolds + ")");
    }
    // initialize the setup
    initTask = new InitTask();
    initTask.setPreprocessing(getPreprocessing());
    initTask.setOperativeViews(operativeViews);
    initTask.setType(initTask.getType() + "-" + experimentName);
    initTask.setAttribute(TC_TASK_TYPE, TcTaskType.INIT_TRAIN.toString());
    // inner batch task (carried out numFolds times)
    DefaultBatchTask crossValidationTask = new DefaultBatchTask() {

        @Discriminator(name = DIM_FEATURE_MODE)
        private String featureMode;

        @Discriminator(name = DIM_CROSS_VALIDATION_MANUAL_FOLDS)
        private boolean useCrossValidationManualFolds;

        /**
         * Builds the fold parameter space from the serialized CAS files produced by the init
         * task. If manual folds are not requested and there are fewer files than folds, the
         * files are split first so that enough of them are available.
         */
        @Override
        public void initialize(TaskContext aContext) {
            super.initialize(aContext);
            File xmiPathRoot = aContext.getFolder(InitTask.OUTPUT_KEY_TRAIN, AccessMode.READONLY);
            String[] fileNames = listSortedBinFilePaths(xmiPathRoot);
            if (numFolds == LEAVE_ONE_OUT) {
                numFolds = fileNames.length;
            }
            // manual mode is turned off and there are fewer files than folds
            if (!useCrossValidationManualFolds && fileNames.length < numFolds) {
                xmiPathRoot = createRequestedNumberOfCas(xmiPathRoot, fileNames.length, featureMode);
                // sorted as well, so the fold input order is deterministic in this path too
                fileNames = listSortedBinFilePaths(xmiPathRoot);
            }
            // don't change any names!!
            FoldDimensionBundle<String> foldDim = getFoldDim(fileNames);
            Dimension<File> filesRootDim = Dimension.create(DIM_FILES_ROOT, xmiPathRoot);
            ParameterSpace pSpace = new ParameterSpace(foldDim, filesRootDim);
            setParameterSpace(pSpace);
        }

        /**
         * Collects the absolute paths of all files with extension {@code bin} below the given
         * folder (recursively) and sorts them, so that the result does not depend on the order
         * in which the file system lists directory entries.
         *
         * @param root
         *            folder to search
         * @return sorted absolute file paths
         */
        private String[] listSortedBinFilePaths(File root) {
            Collection<File> files = FileUtils.listFiles(root, new String[] { "bin" }, true);
            String[] paths = new String[files.size()];
            int i = 0;
            for (File f : files) {
                // adding file paths, not names
                paths[i] = f.getAbsolutePath();
                i++;
            }
            Arrays.sort(paths);
            return paths;
        }

        /**
         * creates required number of CAS
         *
         * @param xmiPathRoot
         *            input path
         * @param numAvailableJCas
         *            all CAS
         * @param featureMode
         *            the feature mode
         * @return a file
         * @throws IllegalStateException
         *             if the split fails or does not yield at least {@code numFolds} files
         */
        private File createRequestedNumberOfCas(File xmiPathRoot, int numAvailableJCas, String featureMode) {
            try {
                File outputFolder = FoldUtil.createMinimalSplit(xmiPathRoot.getAbsolutePath(), numFolds, numAvailableJCas, FM_SEQUENCE.equals(featureMode));
                if (outputFolder == null) {
                    throw new NullPointerException("Output folder is null");
                }
                verifyThatNeededNumberOfCasWasCreated(outputFolder);
                return outputFolder;
            } catch (IllegalStateException e) {
                // already descriptive; do not wrap it in a second IllegalStateException
                throw e;
            } catch (Exception e) {
                throw new IllegalStateException(e);
            }
        }

        /**
         * Checks that {@code outputFolder} contains at least {@code numFolds} files whose name
         * contains {@code .bin}.
         * NOTE(review): only direct children are counted here, whereas
         * listSortedBinFilePaths searches recursively — confirm the split is always flat.
         *
         * @throws IllegalStateException
         *             if fewer files than folds were found
         */
        private void verifyThatNeededNumberOfCasWasCreated(File outputFolder) {
            int numCas = 0;
            File[] listFiles = outputFolder.listFiles();
            if (listFiles == null) {
                throw new NullPointerException("Retrieving files in folder led to a NullPointer");
            }
            for (File f : listFiles) {
                if (f.getName().contains(".bin")) {
                    numCas++;
                }
            }
            if (numCas < numFolds) {
                throw new IllegalStateException("Not enough TextClassificationUnits found to create at least [" + numFolds + "] folds");
            }
        }
    };
    // ================== SUBTASKS OF THE INNER BATCH TASK =======================
    // collecting meta features only on the training data (numFolds times)
    collectionTask = new OutcomeCollectionTask();
    collectionTask.setType(collectionTask.getType() + "-" + experimentName);
    collectionTask.setAttribute(TC_TASK_TYPE, TcTaskType.COLLECTION.toString());
    collectionTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN);
    metaTask = new MetaInfoTask();
    metaTask.setOperativeViews(operativeViews);
    metaTask.setType(metaTask.getType() + "-" + experimentName);
    metaTask.setAttribute(TC_TASK_TYPE, TcTaskType.META.toString());
    // extracting features from training data (numFolds times)
    extractFeaturesTrainTask = new ExtractFeaturesTask();
    extractFeaturesTrainTask.setTesting(false);
    extractFeaturesTrainTask.setType(extractFeaturesTrainTask.getType() + "-Train-" + experimentName);
    extractFeaturesTrainTask.addImport(metaTask, MetaInfoTask.META_KEY);
    extractFeaturesTrainTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN, ExtractFeaturesTask.INPUT_KEY);
    extractFeaturesTrainTask.addImport(collectionTask, OutcomeCollectionTask.OUTPUT_KEY, ExtractFeaturesTask.COLLECTION_INPUT_KEY);
    extractFeaturesTrainTask.setAttribute(TC_TASK_TYPE, TcTaskType.FEATURE_EXTRACTION_TRAIN.toString());
    // extracting features from test data (numFolds times)
    extractFeaturesTestTask = new ExtractFeaturesTask();
    extractFeaturesTestTask.setTesting(true);
    extractFeaturesTestTask.setType(extractFeaturesTestTask.getType() + "-Test-" + experimentName);
    extractFeaturesTestTask.addImport(metaTask, MetaInfoTask.META_KEY);
    extractFeaturesTestTask.addImport(extractFeaturesTrainTask, ExtractFeaturesTask.OUTPUT_KEY);
    extractFeaturesTestTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN, ExtractFeaturesTask.INPUT_KEY);
    extractFeaturesTestTask.addImport(collectionTask, OutcomeCollectionTask.OUTPUT_KEY, ExtractFeaturesTask.COLLECTION_INPUT_KEY);
    extractFeaturesTestTask.setAttribute(TC_TASK_TYPE, TcTaskType.FEATURE_EXTRACTION_TEST.toString());
    // test task operating on the models of the feature extraction train and test tasks
    List<ReportBase> reports = new ArrayList<>();
    reports.add(new BasicResultReport());
    testTask = new DKProTcShallowTestTask(extractFeaturesTrainTask, extractFeaturesTestTask, collectionTask, reports, experimentName);
    testTask.setType(testTask.getType() + "-" + experimentName);
    testTask.setAttribute(TC_TASK_TYPE, TcTaskType.FACADE_TASK.toString());
    if (innerReports != null) {
        for (Class<? extends Report> report : innerReports) {
            testTask.addReport(report);
        }
    }
    testTask.addImport(extractFeaturesTrainTask, ExtractFeaturesTask.OUTPUT_KEY, TEST_TASK_INPUT_KEY_TRAINING_DATA);
    testTask.addImport(extractFeaturesTestTask, ExtractFeaturesTask.OUTPUT_KEY, TEST_TASK_INPUT_KEY_TEST_DATA);
    testTask.addImport(collectionTask, OutcomeCollectionTask.OUTPUT_KEY, Constants.OUTCOMES_INPUT_KEY);
    // ================== CONFIG OF THE INNER BATCH TASK =======================
    crossValidationTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN);
    crossValidationTask.setType(crossValidationTask.getType() + "-" + experimentName);
    crossValidationTask.addTask(collectionTask);
    crossValidationTask.addTask(metaTask);
    crossValidationTask.addTask(extractFeaturesTrainTask);
    crossValidationTask.addTask(extractFeaturesTestTask);
    crossValidationTask.addTask(testTask);
    crossValidationTask.setExecutionPolicy(ExecutionPolicy.USE_EXISTING);
    // report of the inner batch task (sums up results for the folds)
    // we want to re-use the old CV report, we need to collect the evaluation.bin files from
    // the test task here (with another report)
    crossValidationTask.addReport(InnerBatchReport.class);
    crossValidationTask.setAttribute(TC_TASK_TYPE, TcTaskType.CROSS_VALIDATION.toString());
    // DKPro Lab issue 38: must be added as *first* task
    addTask(initTask);
    addTask(crossValidationTask);
}
Also used : ReportBase(org.dkpro.lab.reporting.ReportBase) TaskContext(org.dkpro.lab.engine.TaskContext) ArrayList(java.util.ArrayList) MetaInfoTask(org.dkpro.tc.core.task.MetaInfoTask) TextClassificationException(org.dkpro.tc.api.exception.TextClassificationException) InitTask(org.dkpro.tc.core.task.InitTask) ExtractFeaturesTask(org.dkpro.tc.core.task.ExtractFeaturesTask) DKProTcShallowTestTask(org.dkpro.tc.core.task.DKProTcShallowTestTask) BasicResultReport(org.dkpro.tc.ml.report.BasicResultReport) ParameterSpace(org.dkpro.lab.task.ParameterSpace) OutcomeCollectionTask(org.dkpro.tc.core.task.OutcomeCollectionTask) File(java.io.File) DefaultBatchTask(org.dkpro.lab.task.impl.DefaultBatchTask)

Example 9 with TaskContext

Use of org.dkpro.lab.engine.TaskContext in project dkpro-tc by dkpro.

The class DeepLearningExperimentCrossValidation, method init.

/**
 * Initializes the experiment. This is called automatically before execution. It's not done
 * directly in the constructor, because we want to be able to use setters instead of the
 * three-argument constructor.
 *
 * @throws IllegalStateException
 *             if no experiment name is set or fewer than two folds are configured
 */
protected void init() throws IllegalStateException {
    if (experimentName == null) {
        throw new IllegalStateException("You must set an experiment name");
    }
    if (numFolds < 2) {
        throw new IllegalStateException("Number of folds is not configured correctly. Number of folds needs to be at " + "least 2 (but was " + numFolds + ")");
    }
    // initialize the setup
    initTask = new InitTaskDeep();
    initTask.setPreprocessing(getPreprocessing());
    initTask.setOperativeViews(operativeViews);
    initTask.setType(initTask.getType() + "-" + experimentName);
    initTask.setAttribute(TC_TASK_TYPE, TcTaskType.INIT_TRAIN.toString());
    // inner batch task (carried out numFolds times)
    DefaultBatchTask crossValidationTask = new DefaultBatchTask() {

        @Discriminator(name = DIM_FEATURE_MODE)
        private String featureMode;

        @Discriminator(name = DIM_CROSS_VALIDATION_MANUAL_FOLDS)
        private boolean useCrossValidationManualFolds;

        /**
         * Builds the fold parameter space from the serialized CAS files produced by the init
         * task. If manual folds are not requested and there are fewer files than folds, the
         * files are split first so that enough of them are available.
         */
        @Override
        public void initialize(TaskContext aContext) {
            super.initialize(aContext);
            File xmiPathRoot = aContext.getFolder(InitTask.OUTPUT_KEY_TRAIN, AccessMode.READONLY);
            String[] fileNames = listSortedBinFilePaths(xmiPathRoot);
            if (numFolds == LEAVE_ONE_OUT) {
                numFolds = fileNames.length;
            }
            // manual mode is turned off and there are fewer files than folds
            if (!useCrossValidationManualFolds && fileNames.length < numFolds) {
                xmiPathRoot = createRequestedNumberOfCas(xmiPathRoot, fileNames.length, featureMode);
                // sorted as well, so the fold input order is deterministic in this path too
                fileNames = listSortedBinFilePaths(xmiPathRoot);
            }
            // don't change any names!!
            FoldDimensionBundle<String> foldDim = getFoldDim(fileNames);
            Dimension<File> filesRootDim = Dimension.create(DIM_FILES_ROOT, xmiPathRoot);
            ParameterSpace pSpace = new ParameterSpace(foldDim, filesRootDim);
            setParameterSpace(pSpace);
        }

        /**
         * Collects the absolute paths of all files with extension {@code bin} below the given
         * folder (recursively) and sorts them, so that the result does not depend on the order
         * in which the file system lists directory entries.
         *
         * @param root
         *            folder to search
         * @return sorted absolute file paths
         */
        private String[] listSortedBinFilePaths(File root) {
            Collection<File> files = FileUtils.listFiles(root, new String[] { "bin" }, true);
            String[] paths = new String[files.size()];
            int i = 0;
            for (File f : files) {
                // adding file paths, not names
                paths[i] = f.getAbsolutePath();
                i++;
            }
            Arrays.sort(paths);
            return paths;
        }

        /**
         * creates required number of CAS
         *
         * @param xmiPathRoot
         *            input path
         * @param numAvailableJCas
         *            all CAS
         * @param featureMode
         *            the feature mode
         * @return a file
         * @throws IllegalStateException
         *             if the split fails or does not yield at least {@code numFolds} files
         */
        private File createRequestedNumberOfCas(File xmiPathRoot, int numAvailableJCas, String featureMode) {
            try {
                File outputFolder = FoldUtil.createMinimalSplit(xmiPathRoot.getAbsolutePath(), numFolds, numAvailableJCas, FM_SEQUENCE.equals(featureMode));
                if (outputFolder == null) {
                    throw new NullPointerException("Output folder is null");
                }
                verifyThatNeededNumberOfCasWasCreated(outputFolder);
                return outputFolder;
            } catch (IllegalStateException e) {
                // already descriptive; do not wrap it in a second IllegalStateException
                throw e;
            } catch (Exception e) {
                throw new IllegalStateException(e);
            }
        }

        /**
         * Checks that {@code outputFolder} contains at least {@code numFolds} files whose name
         * contains {@code .bin}.
         * NOTE(review): only direct children are counted here, whereas
         * listSortedBinFilePaths searches recursively — confirm the split is always flat.
         *
         * @throws IllegalStateException
         *             if fewer files than folds were found
         */
        private void verifyThatNeededNumberOfCasWasCreated(File outputFolder) {
            int numCas = 0;
            File[] listFiles = outputFolder.listFiles();
            if (listFiles == null) {
                throw new NullPointerException("Retrieving files in folder led to a NullPointer");
            }
            for (File f : listFiles) {
                if (f.getName().contains(".bin")) {
                    numCas++;
                }
            }
            if (numCas < numFolds) {
                throw new IllegalStateException("Not enough TextClassificationUnits found to create at least [" + numFolds + "] folds");
            }
        }
    };
    // ================== SUBTASKS OF THE INNER BATCH TASK
    // =======================
    // collecting meta features only on the training data (numFolds times)
    // get some meta data depending on the whole document collection
    preparationTask = new PreparationTask();
    preparationTask.setType(preparationTask.getType() + "-" + experimentName);
    preparationTask.setMachineLearningAdapter(mlAdapter);
    preparationTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN, PreparationTask.INPUT_KEY_TRAIN);
    preparationTask.setAttribute(TC_TASK_TYPE, TcTaskType.PREPARATION.toString());
    embeddingTask = new EmbeddingTask();
    embeddingTask.setType(embeddingTask.getType() + "-" + experimentName);
    embeddingTask.addImport(preparationTask, PreparationTask.OUTPUT_KEY, EmbeddingTask.INPUT_MAPPING);
    embeddingTask.setAttribute(TC_TASK_TYPE, TcTaskType.EMBEDDING.toString());
    // feature extraction on training data
    vectorizationTrainTask = new VectorizationTask();
    vectorizationTrainTask.setType(vectorizationTrainTask.getType() + "-Train-" + experimentName);
    vectorizationTrainTask.setTesting(false);
    vectorizationTrainTask.addImport(preparationTask, PreparationTask.OUTPUT_KEY, VectorizationTask.MAPPING_INPUT_KEY);
    vectorizationTrainTask.setAttribute(TC_TASK_TYPE, TcTaskType.VECTORIZATION_TRAIN.toString());
    // feature extraction on test data
    vectorizationTestTask = new VectorizationTask();
    vectorizationTestTask.setType(vectorizationTestTask.getType() + "-Test-" + experimentName);
    vectorizationTestTask.setTesting(true);
    vectorizationTestTask.addImport(preparationTask, PreparationTask.OUTPUT_KEY, VectorizationTask.MAPPING_INPUT_KEY);
    // the test-type attribute belongs on the test task, not the train task
    vectorizationTestTask.setAttribute(TC_TASK_TYPE, TcTaskType.VECTORIZATION_TEST.toString());
    // test task operating on the models of the feature extraction train and
    // test tasks
    learningTask = mlAdapter.getTestTask();
    learningTask.setType(learningTask.getType() + "-" + experimentName);
    learningTask.setAttribute(TC_TASK_TYPE, TcTaskType.MACHINE_LEARNING_ADAPTER.toString());
    if (innerReports != null) {
        for (Class<? extends Report> report : innerReports) {
            learningTask.addReport(report);
        }
    }
    // always add the outcome-ID, baseline, meta-collection and basic result reports
    learningTask.addReport(mlAdapter.getOutcomeIdReportClass());
    learningTask.addReport(mlAdapter.getMajorityBaselineIdReportClass());
    learningTask.addReport(mlAdapter.getRandomBaselineIdReportClass());
    learningTask.addReport(mlAdapter.getMetaCollectionReport());
    learningTask.addReport(BasicResultReport.class);
    learningTask.addImport(preparationTask, PreparationTask.OUTPUT_KEY, TcDeepLearningAdapter.PREPARATION_FOLDER);
    learningTask.addImport(vectorizationTrainTask, VectorizationTask.OUTPUT_KEY, Constants.TEST_TASK_INPUT_KEY_TRAINING_DATA);
    learningTask.addImport(vectorizationTestTask, VectorizationTask.OUTPUT_KEY, Constants.TEST_TASK_INPUT_KEY_TEST_DATA);
    learningTask.addImport(embeddingTask, EmbeddingTask.OUTPUT_KEY, TcDeepLearningAdapter.EMBEDDING_FOLDER);
    learningTask.addImport(vectorizationTrainTask, VectorizationTask.OUTPUT_KEY, TcDeepLearningAdapter.VECTORIZIATION_TRAIN_OUTPUT);
    learningTask.addImport(vectorizationTrainTask, VectorizationTask.OUTPUT_KEY, TcDeepLearningAdapter.TARGET_ID_MAPPING_TRAIN);
    learningTask.addImport(vectorizationTestTask, VectorizationTask.OUTPUT_KEY, TcDeepLearningAdapter.VECTORIZIATION_TEST_OUTPUT);
    learningTask.addImport(vectorizationTestTask, VectorizationTask.OUTPUT_KEY, TcDeepLearningAdapter.TARGET_ID_MAPPING_TEST);
    // ================== CONFIG OF THE INNER BATCH TASK
    // =======================
    crossValidationTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN);
    crossValidationTask.setType(crossValidationTask.getType() + "-" + experimentName);
    crossValidationTask.addTask(preparationTask);
    crossValidationTask.addTask(embeddingTask);
    crossValidationTask.addTask(vectorizationTrainTask);
    crossValidationTask.addTask(vectorizationTestTask);
    crossValidationTask.addTask(learningTask);
    crossValidationTask.setExecutionPolicy(ExecutionPolicy.USE_EXISTING);
    // report of the inner batch task (sums up results for the folds)
    // we want to re-use the old CV report, we need to collect the
    // evaluation.bin files from
    // the test task here (with another report)
    crossValidationTask.addReport(DeepLearningInnerBatchReport.class);
    crossValidationTask.setAttribute(TC_TASK_TYPE, TcTaskType.CROSS_VALIDATION.toString());
    // DKPro Lab issue 38: must be added as *first* task
    addTask(initTask);
    addTask(crossValidationTask);
}
Also used : TaskContext(org.dkpro.lab.engine.TaskContext) PreparationTask(org.dkpro.tc.core.task.deep.PreparationTask) InitTaskDeep(org.dkpro.tc.core.task.deep.InitTaskDeep) TextClassificationException(org.dkpro.tc.api.exception.TextClassificationException) ParameterSpace(org.dkpro.lab.task.ParameterSpace) VectorizationTask(org.dkpro.tc.core.task.deep.VectorizationTask) File(java.io.File) EmbeddingTask(org.dkpro.tc.core.task.deep.EmbeddingTask) DefaultBatchTask(org.dkpro.lab.task.impl.DefaultBatchTask)

Example 10 with TaskContext

Use of org.dkpro.lab.engine.TaskContext in project dkpro-lab by dkpro.

The class SimpleExecutionEngine, method run.

/**
 * Executes a {@link UimaTask} single-threadedly: every document produced by the task's
 * collection reader is passed through the task's analysis engine, and progress is reported
 * to the task context.
 *
 * @param aConfiguration
 *            the task to run; must be a {@link UimaTask}
 * @return the ID of the task context the task was executed in
 * @throws ExecutionException
 *             if the task is not a UIMA task or if setting up or running the pipeline fails
 * @throws LifeCycleException
 *             if one of the life-cycle callbacks fails
 */
@Override
public String run(Task aConfiguration) throws ExecutionException, LifeCycleException {
    if (!(aConfiguration instanceof UimaTask)) {
        throw new ExecutionException("This engine can only execute [" + UimaTask.class.getName() + "]");
    }
    UimaTask configuration = (UimaTask) aConfiguration;
    // Create persistence service for injection into analysis components
    TaskContext ctx = contextFactory.createContext(aConfiguration);
    try {
        ResourceManager resMgr = newDefaultResourceManager();
        // Make sure the descriptor is fully resolved. It will be modified and
        // thus should not be modified again afterwards by UIMA.
        AnalysisEngineDescription analysisDesc = configuration.getAnalysisEngineDescription(ctx);
        analysisDesc.resolveImports(resMgr);
        if (analysisDesc.getMetaData().getName() == null) {
            analysisDesc.getMetaData().setName("Analysis for " + aConfiguration.getType());
        }
        // Scan components that accept the service and bind it to them
        bindResource(analysisDesc, TaskContext.class, TaskContextProvider.class, TaskContextProvider.PARAM_FACTORY_NAME, contextFactory.getId(), TaskContextProvider.PARAM_CONTEXT_ID, ctx.getId());
        // Set up UIMA context & logging
        Logger logger = new UimaLoggingAdapter(ctx);
        UimaContextAdmin uimaCtx = newUimaContext(logger, resMgr, newConfigurationManager());
        // Set up reader
        CollectionReaderDescription readerDesc = configuration.getCollectionReaderDescription(ctx);
        if (readerDesc.getMetaData().getName() == null) {
            readerDesc.getMetaData().setName("Reader for " + aConfiguration.getType());
        }
        Map<String, Object> addReaderParam = new HashMap<String, Object>();
        addReaderParam.put(Resource.PARAM_UIMA_CONTEXT, uimaCtx);
        addReaderParam.put(Resource.PARAM_RESOURCE_MANAGER, resMgr);
        CollectionReader reader = produceCollectionReader(readerDesc, resMgr, addReaderParam);
        try {
            // Set up analysis engine. It shares the resource manager and the logging adapter
            // with the reader but gets its own context and configuration manager, so that its
            // root-level configuration parameters are managed separately from the reader's.
            AnalysisEngine engine;
            if (analysisDesc.isPrimitive()) {
                engine = new PrimitiveAnalysisEngine_impl();
            } else {
                engine = new AggregateAnalysisEngine_impl();
            }
            UimaContextAdmin engineCtx = newUimaContext(logger, resMgr, newConfigurationManager());
            Map<String, Object> addEngineParam = new HashMap<String, Object>();
            addEngineParam.put(Resource.PARAM_UIMA_CONTEXT, engineCtx);
            addEngineParam.put(Resource.PARAM_RESOURCE_MANAGER, resMgr);
            engine.initialize(analysisDesc, addEngineParam);
            try {
                // Now the setup is complete
                ctx.getLifeCycleManager().initialize(ctx, aConfiguration);
                // Start recording
                ctx.getLifeCycleManager().begin(ctx, aConfiguration);
                // Run the experiment
                // Apply the engine to all documents provided by the reader
                List<ResourceMetaData> metaData = new ArrayList<ResourceMetaData>();
                metaData.add(reader.getMetaData());
                metaData.add(engine.getMetaData());
                CAS cas = CasCreationUtils.createCas(metaData);
                while (reader.hasNext()) {
                    reader.getNext(cas);
                    engine.process(cas);
                    // Read the optional document title before the CAS is reset; it is only
                    // used in the progress message
                    String documentTitle = "";
                    Feature documentTitleFeature = cas.getDocumentAnnotation().getType().getFeatureByBaseName("documentTitle");
                    if (documentTitleFeature != null) {
                        documentTitle = cas.getDocumentAnnotation().getFeatureValueAsString(documentTitleFeature);
                    }
                    cas.reset();
                    Progress[] progresses = reader.getProgress();
                    if (progresses != null) {
                        for (Progress p : progresses) {
                            ctx.message("Progress " + readerDesc.getImplementationName() + " " + p.getCompleted() + "/" + p.getTotal() + " " + p.getUnit() + " " + "(" + documentTitle + ")");
                        }
                    }
                }
                // Shut down engine and reader
                engine.collectionProcessComplete();
                reader.close();
            } finally {
                // also release the engine if processing failed
                engine.destroy();
            }
        } finally {
            // also release the reader if engine setup or processing failed
            reader.destroy();
        }
        // End recording
        ctx.getLifeCycleManager().complete(ctx, aConfiguration);
        return ctx.getId();
    } catch (LifeCycleException e) {
        ctx.getLifeCycleManager().fail(ctx, aConfiguration, e);
        throw e;
    } catch (Throwable e) {
        ctx.getLifeCycleManager().fail(ctx, aConfiguration, e);
        throw new ExecutionException(e);
    } finally {
        if (ctx != null) {
            ctx.getLifeCycleManager().destroy(ctx, aConfiguration);
        }
    }
}
Also used : HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) LifeCycleException(org.dkpro.lab.engine.LifeCycleException) Logger(org.apache.uima.util.Logger) Feature(org.apache.uima.cas.Feature) PrimitiveAnalysisEngine_impl(org.apache.uima.analysis_engine.impl.PrimitiveAnalysisEngine_impl) UimaTask(org.dkpro.lab.uima.task.UimaTask) ExecutionException(org.dkpro.lab.engine.ExecutionException) UimaLoggingAdapter(org.dkpro.lab.uima.task.impl.UimaLoggingAdapter) Progress(org.apache.uima.util.Progress) TaskContext(org.dkpro.lab.engine.TaskContext) UIMAFramework.produceCollectionReader(org.apache.uima.UIMAFramework.produceCollectionReader) CollectionReader(org.apache.uima.collection.CollectionReader) ResourceManager(org.apache.uima.resource.ResourceManager) UIMAFramework.newDefaultResourceManager(org.apache.uima.UIMAFramework.newDefaultResourceManager) AggregateAnalysisEngine_impl(org.apache.uima.analysis_engine.impl.AggregateAnalysisEngine_impl) CollectionReaderDescription(org.apache.uima.collection.CollectionReaderDescription) CAS(org.apache.uima.cas.CAS) AnalysisEngineDescription(org.apache.uima.analysis_engine.AnalysisEngineDescription) UimaContextAdmin(org.apache.uima.UimaContextAdmin) ResourceMetaData(org.apache.uima.resource.metadata.ResourceMetaData) AnalysisEngine(org.apache.uima.analysis_engine.AnalysisEngine)

Aggregations (number of usages):

- TaskContext (org.dkpro.lab.engine.TaskContext): 17
- DefaultBatchTask (org.dkpro.lab.task.impl.DefaultBatchTask): 11
- ExecutableTaskBase (org.dkpro.lab.task.impl.ExecutableTaskBase): 9
- Test (org.junit.Test): 9
- ParameterSpace (org.dkpro.lab.task.ParameterSpace): 7
- File (java.io.File): 5
- ExecutionException (org.dkpro.lab.engine.ExecutionException): 5
- Task (org.dkpro.lab.task.Task): 5
- LifeCycleException (org.dkpro.lab.engine.LifeCycleException): 4
- ArrayList (java.util.ArrayList): 3
- ImsCwbWriter (de.tudarmstadt.ukp.dkpro.core.io.imscwb.ImsCwbWriter): 2
- XmiWriter (de.tudarmstadt.ukp.dkpro.core.io.xmi.XmiWriter): 2
- SnowballStemmer (de.tudarmstadt.ukp.dkpro.core.snowball.SnowballStemmer): 2
- BreakIteratorSegmenter (de.tudarmstadt.ukp.dkpro.core.tokit.BreakIteratorSegmenter): 2
- Properties (java.util.Properties): 2
- UIMAFramework.newDefaultResourceManager (org.apache.uima.UIMAFramework.newDefaultResourceManager): 2
- AnalysisEngineDescription (org.apache.uima.analysis_engine.AnalysisEngineDescription): 2
- ResourceManager (org.apache.uima.resource.ResourceManager): 2
- ExamplePosAnnotator (org.dkpro.lab.ml.example.ExamplePosAnnotator): 2
- UnresolvedImportException (org.dkpro.lab.storage.UnresolvedImportException): 2