Search in sources :

Example 6 with DefaultBatchTask

use of org.dkpro.lab.task.impl.DefaultBatchTask in project dkpro-lab by dkpro.

the class MultiThreadBatchTaskTest method testNested2.

/**
 * Runs a batch task that contains another batch task. The nested task builds its own
 * parameter space from the value injected into its {@code outer} discriminator, so the
 * number of "inner" values follows the current "outer" value (1, 2 or 3).
 */
@Test
public void testNested2() throws Exception {
    // Nested batch: its parameter space is computed on demand from the injected "outer" value
    DefaultBatchTask nestedBatch = new DefaultBatchTask() {

        @Discriminator
        private Integer outer;

        @Override
        public ParameterSpace getParameterSpace() {
            // One "inner" value for every integer in [0, outer)
            Integer[] innerValues = new Integer[outer];
            int next = 0;
            while (next < outer) {
                innerValues[next] = next;
                next++;
            }
            return new ParameterSpace(Dimension.create("inner", innerValues));
        }

        @Override
        public void setConfiguration(Map<String, Object> aConfig) {
            super.setConfiguration(aConfig);
            // Trace which instance received which configuration
            System.out.printf("A %10d %s %s%n", hashCode(), getType(), aConfig);
        }
    };
    nestedBatch.addTask(new ConfigDumperTask1());

    // Enclosing batch: iterates over the fixed "outer" dimension
    DefaultBatchTask enclosingBatch = new DefaultBatchTask() {

        @Override
        public void setConfiguration(Map<String, Object> aConfig) {
            super.setConfiguration(aConfig);
            // Trace which instance received which configuration
            System.out.printf("B %10d %s %s%n", hashCode(), getType(), aConfig);
        }
    };
    enclosingBatch.setParameterSpace(new ParameterSpace(Dimension.create("outer", 1, 2, 3)));
    enclosingBatch.addTask(nestedBatch);
    enclosingBatch.addTask(new ConfigDumperTask2());

    Lab.getInstance().run(enclosingBatch);
}
Also used : Map(java.util.Map) DefaultBatchTask(org.dkpro.lab.task.impl.DefaultBatchTask) Test(org.junit.Test)

Example 7 with DefaultBatchTask

use of org.dkpro.lab.task.impl.DefaultBatchTask in project dkpro-lab by dkpro.

the class PosExampleCrf method run.

/**
 * Runs the part-of-speech example that writes its training data through
 * {@code DefaultMalletCRFDataWriterFactory}. Four tasks are chained inside one batch:
 * preprocessing (XMI output), feature extraction (MODEL output), training (on MODEL) and
 * analysis (reads MODEL, writes TSV).
 */
@Test
public void run() throws Exception {
    // Make UIMA log through log4j
    System.setProperty("org.apache.uima.logger.class", "org.apache.uima.util.impl.Log4jLogger_impl");
    clean();

    // Stage 1: read the corpus, stem it and store the result as gzip-compressed XMI
    Task preprocessing = new UimaTaskBase() {

        @Discriminator
        String corpusPath;

        {
            setType("Preprocessing");
        }

        @Override
        public CollectionReaderDescription getCollectionReaderDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            return createReader(NegraExportReader.class,
                    NegraExportReader.PARAM_SOURCE_LOCATION, corpusPath,
                    NegraExportReader.PARAM_LANGUAGE, "de");
        }

        @Override
        public AnalysisEngineDescription getAnalysisEngineDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            String xmiLocation = aContext.getFolder("XMI", AccessMode.READWRITE).getAbsolutePath();
            return createEngine(
                    createEngine(SnowballStemmer.class),
                    createEngine(XmiWriter.class,
                            XmiWriter.PARAM_TARGET_LOCATION, xmiLocation,
                            XmiWriter.PARAM_COMPRESSION, CompressionMethod.GZIP));
        }
    };

    // Stage 2: read the XMI files back and write training data into the MODEL folder
    Task featureExtraction = new UimaTaskBase() {

        {
            setType("FeatureExtraction");
        }

        @Override
        public CollectionReaderDescription getCollectionReaderDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            String xmiLocation = aContext.getFolder("XMI", AccessMode.READONLY).getAbsolutePath();
            String[] xmiPatterns = { "[+]**/*.xmi.gz" };
            return createReader(XmiReader.class,
                    XmiReader.PARAM_SOURCE_LOCATION, xmiLocation,
                    XmiReader.PARAM_PATTERNS, xmiPatterns);
        }

        @Override
        public AnalysisEngineDescription getAnalysisEngineDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            String modelLocation = aContext.getFolder("MODEL", AccessMode.READWRITE).getAbsolutePath();
            return createEngine(createEngineDescription(ExamplePosAnnotator.class,
                    ExamplePosAnnotator.PARAM_DATA_WRITER_FACTORY_CLASS_NAME,
                    DefaultMalletCRFDataWriterFactory.class.getName(),
                    DefaultMalletCRFDataWriterFactory.PARAM_OUTPUT_DIRECTORY, modelLocation));
        }
    };

    // Stage 3: train and package the classifier inside the MODEL folder
    Task training = new ExecutableTaskBase() {

        {
            setType("TrainingTask");
        }

        @Override
        public void execute(TaskContext aContext) throws Exception {
            File modelFolder = aContext.getFolder("MODEL", AccessMode.READWRITE);
            String[] noTrainingArguments = new String[0];
            JarClassifierBuilder<?> builder = JarClassifierBuilder.fromTrainingDirectory(modelFolder);
            builder.trainClassifier(modelFolder, noTrainingArguments);
            builder.packageClassifier(modelFolder);
        }
    };

    // Stage 4: apply the packaged model to the test texts and write the result to TSV
    Task analysis = new UimaTaskBase() {

        {
            setType("AnalysisTask");
        }

        @Override
        public CollectionReaderDescription getCollectionReaderDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            String[] textPatterns = { "[+]**/*.txt" };
            return createReaderDescription(TextReader.class,
                    TextReader.PARAM_SOURCE_LOCATION, "src/test/resources/text",
                    TextReader.PARAM_PATTERNS, textPatterns,
                    TextReader.PARAM_LANGUAGE, "de");
        }

        @Override
        public AnalysisEngineDescription getAnalysisEngineDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            File modelJar = new File(aContext.getFolder("MODEL", AccessMode.READONLY), "model.jar");
            File tsvFile = new File(aContext.getFolder("TSV", AccessMode.READWRITE), "output.tsv");
            return createEngine(
                    createEngineDescription(BreakIteratorSegmenter.class),
                    createEngineDescription(SnowballStemmer.class),
                    createEngineDescription(ExamplePosAnnotator.class,
                            GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
                            modelJar.getAbsolutePath()),
                    createEngineDescription(ImsCwbWriter.class,
                            ImsCwbWriter.PARAM_TARGET_LOCATION, tsvFile));
        }
    };

    // Wire each stage to the output folder of the stage before it
    featureExtraction.addImport(preprocessing, "XMI");
    training.addImport(featureExtraction, "MODEL");
    analysis.addImport(training, "MODEL");

    // Assemble the batch; tasks are registered in pipeline order
    DefaultBatchTask pipeline = new DefaultBatchTask();
    pipeline.setParameterSpace(new ParameterSpace(Dimension.create("corpusPath", CORPUS_PATH)));
    pipeline.setExecutionPolicy(ExecutionPolicy.USE_EXISTING);
    pipeline.addTask(preprocessing);
    pipeline.addTask(featureExtraction);
    pipeline.addTask(training);
    pipeline.addTask(analysis);

    Lab.getInstance().run(pipeline);
}
Also used : Task(org.dkpro.lab.task.Task) DefaultBatchTask(org.dkpro.lab.task.impl.DefaultBatchTask) UimaTaskBase(org.dkpro.lab.uima.task.impl.UimaTaskBase) TaskContext(org.dkpro.lab.engine.TaskContext) SnowballStemmer(de.tudarmstadt.ukp.dkpro.core.snowball.SnowballStemmer) XmiWriter(de.tudarmstadt.ukp.dkpro.core.io.xmi.XmiWriter) ExecutableTaskBase(org.dkpro.lab.task.impl.ExecutableTaskBase) ExamplePosAnnotator(org.dkpro.lab.ml.example.ExamplePosAnnotator) ParameterSpace(org.dkpro.lab.task.ParameterSpace) BreakIteratorSegmenter(de.tudarmstadt.ukp.dkpro.core.tokit.BreakIteratorSegmenter) ImsCwbWriter(de.tudarmstadt.ukp.dkpro.core.io.imscwb.ImsCwbWriter) File(java.io.File) DefaultBatchTask(org.dkpro.lab.task.impl.DefaultBatchTask) Test(org.junit.Test)

Example 8 with DefaultBatchTask

use of org.dkpro.lab.task.impl.DefaultBatchTask in project dkpro-lab by dkpro.

the class PosExampleMaxEnt method run.

/**
 * Runs the part-of-speech example that writes its training data through
 * {@code ViterbiDataWriterFactory} delegating to {@code DefaultMaxentDataWriterFactory}.
 * The batch is swept over the "iterations" and "cutoff" dimensions, which are injected into
 * the training task and handed to the classifier builder as training arguments.
 */
@Test
public void run() throws Exception {
    // Make UIMA log through log4j
    System.setProperty("org.apache.uima.logger.class", "org.apache.uima.util.impl.Log4jLogger_impl");
    clean();

    // Stage 1: read the corpus, stem it and store the result as gzip-compressed XMI
    Task preprocessing = new UimaTaskBase() {

        @Discriminator
        String corpusPath;

        {
            setType("Preprocessing");
        }

        @Override
        public CollectionReaderDescription getCollectionReaderDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            return createReader(NegraExportReader.class,
                    NegraExportReader.PARAM_SOURCE_LOCATION, corpusPath,
                    NegraExportReader.PARAM_LANGUAGE, "de");
        }

        @Override
        public AnalysisEngineDescription getAnalysisEngineDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            String xmiLocation = aContext.getFolder("XMI", AccessMode.READWRITE).getAbsolutePath();
            return createEngine(
                    createEngine(SnowballStemmer.class),
                    createEngine(XmiWriter.class,
                            XmiWriter.PARAM_TARGET_LOCATION, xmiLocation,
                            XmiWriter.PARAM_COMPRESSION, CompressionMethod.GZIP));
        }
    };

    // Stage 2: read the XMI files back and write training data into the MODEL folder
    Task featureExtraction = new UimaTaskBase() {

        {
            setType("FeatureExtraction");
        }

        @Override
        public CollectionReaderDescription getCollectionReaderDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            String xmiLocation = aContext.getFolder("XMI", AccessMode.READONLY).getAbsolutePath();
            String[] xmiPatterns = { "[+]**/*.xmi.gz" };
            return createReader(XmiReader.class,
                    XmiReader.PARAM_SOURCE_LOCATION, xmiLocation,
                    XmiReader.PARAM_PATTERNS, xmiPatterns);
        }

        @Override
        public AnalysisEngineDescription getAnalysisEngineDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            String modelLocation = aContext.getFolder("MODEL", AccessMode.READWRITE).getAbsolutePath();
            return createEngine(createEngineDescription(ExamplePosAnnotator.class,
                    ExamplePosAnnotator.PARAM_DATA_WRITER_FACTORY_CLASS_NAME,
                    ViterbiDataWriterFactory.class.getName(),
                    ViterbiDataWriterFactory.PARAM_OUTPUT_DIRECTORY, modelLocation,
                    ViterbiDataWriterFactory.PARAM_DELEGATED_DATA_WRITER_FACTORY_CLASS,
                    DefaultMaxentDataWriterFactory.class.getName()));
        }
    };

    // Stage 3: train with the injected iterations/cutoff values and package the classifier
    Task training = new ExecutableTaskBase() {

        @Discriminator
        private int iterations;

        @Discriminator
        private int cutoff;

        {
            setType("TrainingTask");
        }

        @Override
        public void execute(TaskContext aContext) throws Exception {
            File modelFolder = aContext.getFolder("MODEL", AccessMode.READWRITE);
            JarClassifierBuilder<?> builder = JarClassifierBuilder.fromTrainingDirectory(modelFolder);
            String[] trainingArguments = { String.valueOf(iterations), String.valueOf(cutoff) };
            builder.trainClassifier(modelFolder, trainingArguments);
            builder.packageClassifier(modelFolder);
        }
    };

    // Stage 4: apply the packaged model to the test texts and write the result to TSV
    Task analysis = new UimaTaskBase() {

        {
            setType("AnalysisTask");
        }

        @Override
        public CollectionReaderDescription getCollectionReaderDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            return createReaderDescription(TextReader.class,
                    TextReader.PARAM_SOURCE_LOCATION, "src/test/resources/text/**/*.txt",
                    TextReader.PARAM_LANGUAGE, "de");
        }

        @Override
        public AnalysisEngineDescription getAnalysisEngineDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            File modelJar = new File(aContext.getFolder("MODEL", AccessMode.READONLY), "model.jar");
            File tsvFile = new File(aContext.getFolder("TSV", AccessMode.READWRITE), "output.tsv");
            return createEngine(
                    createEngineDescription(BreakIteratorSegmenter.class),
                    createEngineDescription(SnowballStemmer.class),
                    createEngineDescription(ExamplePosAnnotator.class,
                            GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
                            modelJar.getAbsolutePath()),
                    createEngineDescription(ImsCwbWriter.class,
                            ImsCwbWriter.PARAM_TARGET_LOCATION, tsvFile));
        }
    };

    // Wire each stage to the output folder of the stage before it
    featureExtraction.addImport(preprocessing, "XMI");
    training.addImport(featureExtraction, "MODEL");
    analysis.addImport(training, "MODEL");

    // Assemble the batch; tasks are registered in pipeline order
    DefaultBatchTask pipeline = new DefaultBatchTask();
    pipeline.setParameterSpace(new ParameterSpace(
            Dimension.create("corpusPath", CORPUS_PATH),
            Dimension.create("iterations", 20, 50, 100),
            Dimension.create("cutoff", 5)));
    pipeline.setExecutionPolicy(ExecutionPolicy.USE_EXISTING);
    pipeline.addTask(preprocessing);
    pipeline.addTask(featureExtraction);
    pipeline.addTask(training);
    pipeline.addTask(analysis);

    Lab.getInstance().run(pipeline);
}
Also used : Task(org.dkpro.lab.task.Task) DefaultBatchTask(org.dkpro.lab.task.impl.DefaultBatchTask) UimaTaskBase(org.dkpro.lab.uima.task.impl.UimaTaskBase) TaskContext(org.dkpro.lab.engine.TaskContext) SnowballStemmer(de.tudarmstadt.ukp.dkpro.core.snowball.SnowballStemmer) XmiWriter(de.tudarmstadt.ukp.dkpro.core.io.xmi.XmiWriter) ExecutableTaskBase(org.dkpro.lab.task.impl.ExecutableTaskBase) ExamplePosAnnotator(org.dkpro.lab.ml.example.ExamplePosAnnotator) ParameterSpace(org.dkpro.lab.task.ParameterSpace) BreakIteratorSegmenter(de.tudarmstadt.ukp.dkpro.core.tokit.BreakIteratorSegmenter) ImsCwbWriter(de.tudarmstadt.ukp.dkpro.core.io.imscwb.ImsCwbWriter) File(java.io.File) DefaultBatchTask(org.dkpro.lab.task.impl.DefaultBatchTask) Test(org.junit.Test)

Example 9 with DefaultBatchTask

use of org.dkpro.lab.task.impl.DefaultBatchTask in project dkpro-tc by dkpro.

the class ExperimentCrossValidation method init.

/**
 * Initializes the experiment. This is called automatically before execution. It's not done
 * directly in the constructor, because we want to be able to use setters instead of the
 * three-argument constructor.
 * <p>
 * Builds the init task and an inner cross-validation batch task (outcome collection, meta
 * information, feature extraction for train and test, and the test task) and registers both
 * on this experiment.
 *
 * @throws IllegalStateException
 *             if no experiment name is set or the number of folds is smaller than 2
 */
protected void init() throws IllegalStateException {
    if (experimentName == null) {
        throw new IllegalStateException("You must set an experiment name");
    }
    // NOTE(review): the inner task later treats numFolds == LEAVE_ONE_OUT specially; if that
    // constant is smaller than 2 this check rejects it before it can be expanded - confirm.
    if (numFolds < 2) {
        throw new IllegalStateException("Number of folds is not configured correctly. Number of folds needs to be at " + "least 2 (but was " + numFolds + ")");
    }
    // initialize the setup
    initTask = new InitTask();
    initTask.setPreprocessing(getPreprocessing());
    initTask.setOperativeViews(operativeViews);
    initTask.setType(initTask.getType() + "-" + experimentName);
    initTask.setAttribute(TC_TASK_TYPE, TcTaskType.INIT_TRAIN.toString());
    // inner batch task (carried out numFolds times)
    DefaultBatchTask crossValidationTask = new DefaultBatchTask() {

        @Discriminator(name = DIM_FEATURE_MODE)
        private String featureMode;

        @Discriminator(name = DIM_CROSS_VALIDATION_MANUAL_FOLDS)
        private boolean useCrossValidationManualFolds;

        /**
         * Derives the parameter space of this batch task from the {@code .bin} files that
         * the init task wrote: one fold dimension over the file paths plus the root folder.
         */
        @Override
        public void initialize(TaskContext aContext) {
            super.initialize(aContext);
            File xmiPathRoot = aContext.getFolder(InitTask.OUTPUT_KEY_TRAIN, AccessMode.READONLY);
            String[] fileNames = listSortedBinFilePaths(xmiPathRoot);
            if (numFolds == LEAVE_ONE_OUT) {
                // leave-one-out: one fold per available file
                numFolds = fileNames.length;
            }
            // Only executed if there are fewer CAS files than requested folds and the manual
            // fold mode is turned off: split the data further and list the new files.
            if (!useCrossValidationManualFolds && fileNames.length < numFolds) {
                xmiPathRoot = createRequestedNumberOfCas(xmiPathRoot, fileNames.length, featureMode);
                // Sorted like the initial listing, so that the fold assignment does not
                // depend on the order in which the file system returns the files.
                fileNames = listSortedBinFilePaths(xmiPathRoot);
            }
            // don't change any names!!
            FoldDimensionBundle<String> foldDim = getFoldDim(fileNames);
            Dimension<File> filesRootDim = Dimension.create(DIM_FILES_ROOT, xmiPathRoot);
            ParameterSpace pSpace = new ParameterSpace(foldDim, filesRootDim);
            setParameterSpace(pSpace);
        }

        /**
         * Collects the absolute paths (not the bare names) of all files with the extension
         * {@code bin} found by {@code FileUtils.listFiles} below the given folder and
         * returns them sorted.
         *
         * @param root
         *            folder to search
         * @return sorted absolute file paths
         */
        private String[] listSortedBinFilePaths(File root) {
            Collection<File> files = FileUtils.listFiles(root, new String[] { "bin" }, true);
            String[] paths = new String[files.size()];
            int i = 0;
            for (File f : files) {
                paths[i] = f.getAbsolutePath();
                i++;
            }
            Arrays.sort(paths);
            return paths;
        }

        /**
         * creates required number of CAS
         *
         * @param xmiPathRoot
         *            input path
         * @param numAvailableJCas
         *            all CAS
         * @param featureMode
         *            the feature mode
         * @return the folder that contains the newly created split
         * @throws IllegalStateException
         *             wrapping any failure while creating or verifying the split
         */
        private File createRequestedNumberOfCas(File xmiPathRoot, int numAvailableJCas, String featureMode) {
            try {
                File outputFolder = FoldUtil.createMinimalSplit(xmiPathRoot.getAbsolutePath(), numFolds, numAvailableJCas, FM_SEQUENCE.equals(featureMode));
                if (outputFolder == null) {
                    throw new NullPointerException("Output folder is null");
                }
                verifyThatNeededNumberOfCasWasCreated(outputFolder);
                return outputFolder;
            } catch (Exception e) {
                throw new IllegalStateException(e);
            }
        }

        /**
         * Checks that the folder directly contains at least {@code numFolds} entries whose
         * name contains {@code .bin}.
         */
        private void verifyThatNeededNumberOfCasWasCreated(File outputFolder) {
            int numCas = 0;
            File[] listFiles = outputFolder.listFiles();
            if (listFiles == null) {
                throw new NullPointerException("Retrieving files in folder led to a NullPointer");
            }
            for (File f : listFiles) {
                if (f.getName().contains(".bin")) {
                    numCas++;
                }
            }
            if (numCas < numFolds) {
                throw new IllegalStateException("Not enough TextClassificationUnits found to create at least [" + numFolds + "] folds");
            }
        }
    };
    // ================== SUBTASKS OF THE INNER BATCH TASK =======================
    // collecting meta features only on the training data (numFolds times)
    collectionTask = new OutcomeCollectionTask();
    collectionTask.setType(collectionTask.getType() + "-" + experimentName);
    collectionTask.setAttribute(TC_TASK_TYPE, TcTaskType.COLLECTION.toString());
    collectionTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN);
    metaTask = new MetaInfoTask();
    metaTask.setOperativeViews(operativeViews);
    metaTask.setType(metaTask.getType() + "-" + experimentName);
    metaTask.setAttribute(TC_TASK_TYPE, TcTaskType.META.toString());
    // extracting features from training data (numFolds times)
    extractFeaturesTrainTask = new ExtractFeaturesTask();
    extractFeaturesTrainTask.setTesting(false);
    extractFeaturesTrainTask.setType(extractFeaturesTrainTask.getType() + "-Train-" + experimentName);
    extractFeaturesTrainTask.addImport(metaTask, MetaInfoTask.META_KEY);
    extractFeaturesTrainTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN, ExtractFeaturesTask.INPUT_KEY);
    extractFeaturesTrainTask.addImport(collectionTask, OutcomeCollectionTask.OUTPUT_KEY, ExtractFeaturesTask.COLLECTION_INPUT_KEY);
    extractFeaturesTrainTask.setAttribute(TC_TASK_TYPE, TcTaskType.FEATURE_EXTRACTION_TRAIN.toString());
    // extracting features from test data (numFolds times)
    extractFeaturesTestTask = new ExtractFeaturesTask();
    extractFeaturesTestTask.setTesting(true);
    extractFeaturesTestTask.setType(extractFeaturesTestTask.getType() + "-Test-" + experimentName);
    extractFeaturesTestTask.addImport(metaTask, MetaInfoTask.META_KEY);
    extractFeaturesTestTask.addImport(extractFeaturesTrainTask, ExtractFeaturesTask.OUTPUT_KEY);
    extractFeaturesTestTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN, ExtractFeaturesTask.INPUT_KEY);
    extractFeaturesTestTask.addImport(collectionTask, OutcomeCollectionTask.OUTPUT_KEY, ExtractFeaturesTask.COLLECTION_INPUT_KEY);
    extractFeaturesTestTask.setAttribute(TC_TASK_TYPE, TcTaskType.FEATURE_EXTRACTION_TEST.toString());
    // test task operating on the models of the feature extraction train and test tasks
    List<ReportBase> reports = new ArrayList<>();
    reports.add(new BasicResultReport());
    testTask = new DKProTcShallowTestTask(extractFeaturesTrainTask, extractFeaturesTestTask, collectionTask, reports, experimentName);
    testTask.setType(testTask.getType() + "-" + experimentName);
    testTask.setAttribute(TC_TASK_TYPE, TcTaskType.FACADE_TASK.toString());
    if (innerReports != null) {
        for (Class<? extends Report> report : innerReports) {
            testTask.addReport(report);
        }
    }
    testTask.addImport(extractFeaturesTrainTask, ExtractFeaturesTask.OUTPUT_KEY, TEST_TASK_INPUT_KEY_TRAINING_DATA);
    testTask.addImport(extractFeaturesTestTask, ExtractFeaturesTask.OUTPUT_KEY, TEST_TASK_INPUT_KEY_TEST_DATA);
    testTask.addImport(collectionTask, OutcomeCollectionTask.OUTPUT_KEY, Constants.OUTCOMES_INPUT_KEY);
    // ================== CONFIG OF THE INNER BATCH TASK =======================
    crossValidationTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN);
    crossValidationTask.setType(crossValidationTask.getType() + "-" + experimentName);
    crossValidationTask.addTask(collectionTask);
    crossValidationTask.addTask(metaTask);
    crossValidationTask.addTask(extractFeaturesTrainTask);
    crossValidationTask.addTask(extractFeaturesTestTask);
    crossValidationTask.addTask(testTask);
    crossValidationTask.setExecutionPolicy(ExecutionPolicy.USE_EXISTING);
    // report of the inner batch task (sums up results for the folds)
    // we want to re-use the old CV report, we need to collect the evaluation.bin files from
    // the test task here (with another report)
    crossValidationTask.addReport(InnerBatchReport.class);
    crossValidationTask.setAttribute(TC_TASK_TYPE, TcTaskType.CROSS_VALIDATION.toString());
    // DKPro Lab issue 38: must be added as *first* task
    addTask(initTask);
    addTask(crossValidationTask);
}
Also used : ReportBase(org.dkpro.lab.reporting.ReportBase) TaskContext(org.dkpro.lab.engine.TaskContext) ArrayList(java.util.ArrayList) MetaInfoTask(org.dkpro.tc.core.task.MetaInfoTask) TextClassificationException(org.dkpro.tc.api.exception.TextClassificationException) InitTask(org.dkpro.tc.core.task.InitTask) ExtractFeaturesTask(org.dkpro.tc.core.task.ExtractFeaturesTask) DKProTcShallowTestTask(org.dkpro.tc.core.task.DKProTcShallowTestTask) BasicResultReport(org.dkpro.tc.ml.report.BasicResultReport) ParameterSpace(org.dkpro.lab.task.ParameterSpace) OutcomeCollectionTask(org.dkpro.tc.core.task.OutcomeCollectionTask) File(java.io.File) DefaultBatchTask(org.dkpro.lab.task.impl.DefaultBatchTask)

Example 10 with DefaultBatchTask

use of org.dkpro.lab.task.impl.DefaultBatchTask in project dkpro-tc by dkpro.

the class DeepLearningExperimentCrossValidation method init.

/**
 * Initializes the experiment. This is called automatically before execution. It's not done
 * directly in the constructor, because we want to be able to use setters instead of the
 * three-argument constructor.
 * <p>
 * Builds the init task and an inner cross-validation batch task (preparation, embedding,
 * vectorization for train and test, and the learning task obtained from the machine
 * learning adapter) and registers both on this experiment.
 *
 * @throws IllegalStateException
 *             if no experiment name is set or the number of folds is smaller than 2
 */
protected void init() throws IllegalStateException {
    if (experimentName == null) {
        throw new IllegalStateException("You must set an experiment name");
    }
    // NOTE(review): the inner task later treats numFolds == LEAVE_ONE_OUT specially; if that
    // constant is smaller than 2 this check rejects it before it can be expanded - confirm.
    if (numFolds < 2) {
        throw new IllegalStateException("Number of folds is not configured correctly. Number of folds needs to be at " + "least 2 (but was " + numFolds + ")");
    }
    // initialize the setup
    initTask = new InitTaskDeep();
    initTask.setPreprocessing(getPreprocessing());
    initTask.setOperativeViews(operativeViews);
    initTask.setType(initTask.getType() + "-" + experimentName);
    initTask.setAttribute(TC_TASK_TYPE, TcTaskType.INIT_TRAIN.toString());
    // inner batch task (carried out numFolds times)
    DefaultBatchTask crossValidationTask = new DefaultBatchTask() {

        @Discriminator(name = DIM_FEATURE_MODE)
        private String featureMode;

        @Discriminator(name = DIM_CROSS_VALIDATION_MANUAL_FOLDS)
        private boolean useCrossValidationManualFolds;

        /**
         * Derives the parameter space of this batch task from the {@code .bin} files that
         * the init task wrote: one fold dimension over the file paths plus the root folder.
         */
        @Override
        public void initialize(TaskContext aContext) {
            super.initialize(aContext);
            File xmiPathRoot = aContext.getFolder(InitTask.OUTPUT_KEY_TRAIN, AccessMode.READONLY);
            String[] fileNames = listSortedBinFilePaths(xmiPathRoot);
            if (numFolds == LEAVE_ONE_OUT) {
                // leave-one-out: one fold per available file
                numFolds = fileNames.length;
            }
            // Only executed if there are fewer CAS files than requested folds and the manual
            // fold mode is turned off: split the data further and list the new files.
            if (!useCrossValidationManualFolds && fileNames.length < numFolds) {
                xmiPathRoot = createRequestedNumberOfCas(xmiPathRoot, fileNames.length, featureMode);
                // Sorted like the initial listing, so that the fold assignment does not
                // depend on the order in which the file system returns the files.
                fileNames = listSortedBinFilePaths(xmiPathRoot);
            }
            // don't change any names!!
            FoldDimensionBundle<String> foldDim = getFoldDim(fileNames);
            Dimension<File> filesRootDim = Dimension.create(DIM_FILES_ROOT, xmiPathRoot);
            ParameterSpace pSpace = new ParameterSpace(foldDim, filesRootDim);
            setParameterSpace(pSpace);
        }

        /**
         * Collects the absolute paths (not the bare names) of all files with the extension
         * {@code bin} found by {@code FileUtils.listFiles} below the given folder and
         * returns them sorted.
         *
         * @param root
         *            folder to search
         * @return sorted absolute file paths
         */
        private String[] listSortedBinFilePaths(File root) {
            Collection<File> files = FileUtils.listFiles(root, new String[] { "bin" }, true);
            String[] paths = new String[files.size()];
            int i = 0;
            for (File f : files) {
                paths[i] = f.getAbsolutePath();
                i++;
            }
            Arrays.sort(paths);
            return paths;
        }

        /**
         * creates required number of CAS
         *
         * @param xmiPathRoot
         *            input path
         * @param numAvailableJCas
         *            all CAS
         * @param featureMode
         *            the feature mode
         * @return the folder that contains the newly created split
         * @throws IllegalStateException
         *             wrapping any failure while creating or verifying the split
         */
        private File createRequestedNumberOfCas(File xmiPathRoot, int numAvailableJCas, String featureMode) {
            try {
                File outputFolder = FoldUtil.createMinimalSplit(xmiPathRoot.getAbsolutePath(), numFolds, numAvailableJCas, FM_SEQUENCE.equals(featureMode));
                if (outputFolder == null) {
                    throw new NullPointerException("Output folder is null");
                }
                verifyThatNeededNumberOfCasWasCreated(outputFolder);
                return outputFolder;
            } catch (Exception e) {
                throw new IllegalStateException(e);
            }
        }

        /**
         * Checks that the folder directly contains at least {@code numFolds} entries whose
         * name contains {@code .bin}.
         */
        private void verifyThatNeededNumberOfCasWasCreated(File outputFolder) {
            int numCas = 0;
            File[] listFiles = outputFolder.listFiles();
            if (listFiles == null) {
                throw new NullPointerException("Retrieving files in folder led to a NullPointer");
            }
            for (File f : listFiles) {
                if (f.getName().contains(".bin")) {
                    numCas++;
                }
            }
            if (numCas < numFolds) {
                throw new IllegalStateException("Not enough TextClassificationUnits found to create at least [" + numFolds + "] folds");
            }
        }
    };
    // ================== SUBTASKS OF THE INNER BATCH TASK =======================
    // collecting meta features only on the training data (numFolds times)
    // get some meta data depending on the whole document collection
    preparationTask = new PreparationTask();
    preparationTask.setType(preparationTask.getType() + "-" + experimentName);
    preparationTask.setMachineLearningAdapter(mlAdapter);
    preparationTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN, PreparationTask.INPUT_KEY_TRAIN);
    preparationTask.setAttribute(TC_TASK_TYPE, TcTaskType.PREPARATION.toString());
    embeddingTask = new EmbeddingTask();
    embeddingTask.setType(embeddingTask.getType() + "-" + experimentName);
    embeddingTask.addImport(preparationTask, PreparationTask.OUTPUT_KEY, EmbeddingTask.INPUT_MAPPING);
    embeddingTask.setAttribute(TC_TASK_TYPE, TcTaskType.EMBEDDING.toString());
    // feature extraction on training data
    vectorizationTrainTask = new VectorizationTask();
    vectorizationTrainTask.setType(vectorizationTrainTask.getType() + "-Train-" + experimentName);
    vectorizationTrainTask.setTesting(false);
    vectorizationTrainTask.addImport(preparationTask, PreparationTask.OUTPUT_KEY, VectorizationTask.MAPPING_INPUT_KEY);
    vectorizationTrainTask.setAttribute(TC_TASK_TYPE, TcTaskType.VECTORIZATION_TRAIN.toString());
    // feature extraction on test data
    vectorizationTestTask = new VectorizationTask();
    vectorizationTestTask.setType(vectorizationTestTask.getType() + "-Test-" + experimentName);
    vectorizationTestTask.setTesting(true);
    vectorizationTestTask.addImport(preparationTask, PreparationTask.OUTPUT_KEY, VectorizationTask.MAPPING_INPUT_KEY);
    // The attribute belongs on the test task; setting it on the train task would overwrite
    // VECTORIZATION_TRAIN there and leave the test task without a task type.
    vectorizationTestTask.setAttribute(TC_TASK_TYPE, TcTaskType.VECTORIZATION_TEST.toString());
    // test task operating on the models of the feature extraction train and
    // test tasks
    learningTask = mlAdapter.getTestTask();
    learningTask.setType(learningTask.getType() + "-" + experimentName);
    learningTask.setAttribute(TC_TASK_TYPE, TcTaskType.MACHINE_LEARNING_ADAPTER.toString());
    if (innerReports != null) {
        for (Class<? extends Report> report : innerReports) {
            learningTask.addReport(report);
        }
    }
    // always add OutcomeIdReport
    learningTask.addReport(mlAdapter.getOutcomeIdReportClass());
    learningTask.addReport(mlAdapter.getMajorityBaselineIdReportClass());
    learningTask.addReport(mlAdapter.getRandomBaselineIdReportClass());
    learningTask.addReport(mlAdapter.getMetaCollectionReport());
    learningTask.addReport(BasicResultReport.class);
    learningTask.addImport(preparationTask, PreparationTask.OUTPUT_KEY, TcDeepLearningAdapter.PREPARATION_FOLDER);
    learningTask.addImport(vectorizationTrainTask, VectorizationTask.OUTPUT_KEY, Constants.TEST_TASK_INPUT_KEY_TRAINING_DATA);
    learningTask.addImport(vectorizationTestTask, VectorizationTask.OUTPUT_KEY, Constants.TEST_TASK_INPUT_KEY_TEST_DATA);
    learningTask.addImport(embeddingTask, EmbeddingTask.OUTPUT_KEY, TcDeepLearningAdapter.EMBEDDING_FOLDER);
    learningTask.addImport(vectorizationTrainTask, VectorizationTask.OUTPUT_KEY, TcDeepLearningAdapter.VECTORIZIATION_TRAIN_OUTPUT);
    learningTask.addImport(vectorizationTrainTask, VectorizationTask.OUTPUT_KEY, TcDeepLearningAdapter.TARGET_ID_MAPPING_TRAIN);
    learningTask.addImport(vectorizationTestTask, VectorizationTask.OUTPUT_KEY, TcDeepLearningAdapter.VECTORIZIATION_TEST_OUTPUT);
    learningTask.addImport(vectorizationTestTask, VectorizationTask.OUTPUT_KEY, TcDeepLearningAdapter.TARGET_ID_MAPPING_TEST);
    // ================== CONFIG OF THE INNER BATCH TASK =======================
    crossValidationTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN);
    crossValidationTask.setType(crossValidationTask.getType() + "-" + experimentName);
    crossValidationTask.addTask(preparationTask);
    crossValidationTask.addTask(embeddingTask);
    crossValidationTask.addTask(vectorizationTrainTask);
    crossValidationTask.addTask(vectorizationTestTask);
    crossValidationTask.addTask(learningTask);
    crossValidationTask.setExecutionPolicy(ExecutionPolicy.USE_EXISTING);
    // report of the inner batch task (sums up results for the folds)
    // we want to re-use the old CV report, we need to collect the
    // evaluation.bin files from
    // the test task here (with another report)
    crossValidationTask.addReport(DeepLearningInnerBatchReport.class);
    crossValidationTask.setAttribute(TC_TASK_TYPE, TcTaskType.CROSS_VALIDATION.toString());
    // DKPro Lab issue 38: must be added as *first* task
    addTask(initTask);
    addTask(crossValidationTask);
}
Also used : TaskContext(org.dkpro.lab.engine.TaskContext) PreparationTask(org.dkpro.tc.core.task.deep.PreparationTask) InitTaskDeep(org.dkpro.tc.core.task.deep.InitTaskDeep) TextClassificationException(org.dkpro.tc.api.exception.TextClassificationException) ParameterSpace(org.dkpro.lab.task.ParameterSpace) VectorizationTask(org.dkpro.tc.core.task.deep.VectorizationTask) File(java.io.File) EmbeddingTask(org.dkpro.tc.core.task.deep.EmbeddingTask) DefaultBatchTask(org.dkpro.lab.task.impl.DefaultBatchTask)

Aggregations

DefaultBatchTask (org.dkpro.lab.task.impl.DefaultBatchTask)20 Test (org.junit.Test)17 ParameterSpace (org.dkpro.lab.task.ParameterSpace)12 TaskContext (org.dkpro.lab.engine.TaskContext)11 ExecutableTaskBase (org.dkpro.lab.task.impl.ExecutableTaskBase)9 File (java.io.File)6 Map (java.util.Map)5 Task (org.dkpro.lab.task.Task)5 ImsCwbWriter (de.tudarmstadt.ukp.dkpro.core.io.imscwb.ImsCwbWriter)2 XmiWriter (de.tudarmstadt.ukp.dkpro.core.io.xmi.XmiWriter)2 SnowballStemmer (de.tudarmstadt.ukp.dkpro.core.snowball.SnowballStemmer)2 BreakIteratorSegmenter (de.tudarmstadt.ukp.dkpro.core.tokit.BreakIteratorSegmenter)2 Properties (java.util.Properties)2 Lab (org.dkpro.lab.Lab)2 ExamplePosAnnotator (org.dkpro.lab.ml.example.ExamplePosAnnotator)2 PropertiesAdapter (org.dkpro.lab.storage.impl.PropertiesAdapter)2 UimaTaskBase (org.dkpro.lab.uima.task.impl.UimaTaskBase)2 TextClassificationException (org.dkpro.tc.api.exception.TextClassificationException)2 ArrayList (java.util.ArrayList)1 Collection (java.util.Collection)1