Search in sources :

Example 1 with ExtractFeaturesTask

use of org.dkpro.tc.core.task.ExtractFeaturesTask in project dkpro-tc by dkpro.

In class ExperimentCrossValidation, method init:

/**
 * Initializes the experiment. This is called automatically before execution. It's not done
 * directly in the constructor, because we want to be able to use setters instead of the
 * three-argument constructor.
 *
 * <p>
 * The method builds two top-level tasks: the {@code initTask} and an inner batch task that
 * bundles outcome collection, meta-info collection, feature extraction (train and test) and
 * the test task. The inner batch task derives its parameter space (fold dimension and files
 * root) from the {@code .bin} files found in the output folder of the init task.
 *
 * @throws IllegalStateException
 *             if no experiment name was set or the number of folds is smaller than 2.
 */
protected void init() throws IllegalStateException {
    if (experimentName == null) {
        throw new IllegalStateException("You must set an experiment name");
    }
    // NOTE(review): the inner task below still handles numFolds == LEAVE_ONE_OUT; if that
    // constant is smaller than 2 this guard rejects it first - confirm the intended interplay.
    if (numFolds < 2) {
        throw new IllegalStateException("Number of folds is not configured correctly. Number of folds needs to be at " + "least 2 (but was " + numFolds + ")");
    }
    // initialize the setup
    initTask = new InitTask();
    initTask.setPreprocessing(getPreprocessing());
    initTask.setOperativeViews(operativeViews);
    initTask.setType(initTask.getType() + "-" + experimentName);
    initTask.setAttribute(TC_TASK_TYPE, TcTaskType.INIT_TRAIN.toString());
    // inner batch task (carried out numFolds times)
    DefaultBatchTask crossValidationTask = new DefaultBatchTask() {

        @Discriminator(name = DIM_FEATURE_MODE)
        private String featureMode;

        @Discriminator(name = DIM_CROSS_VALIDATION_MANUAL_FOLDS)
        private boolean useCrossValidationManualFolds;

        /**
         * Sets the parameter space of this batch task: a fold dimension built from the
         * {@code .bin} files below the init task's training output folder, plus a dimension
         * holding the root folder those files were read from.
         */
        @Override
        public void initialize(TaskContext aContext) {
            super.initialize(aContext);
            File xmiPathRoot = aContext.getFolder(InitTask.OUTPUT_KEY_TRAIN, AccessMode.READONLY);
            String[] fileNames = listSortedBinFilePaths(xmiPathRoot);
            if (numFolds == LEAVE_ONE_OUT) {
                // leave-one-out: one fold per available file
                numFolds = fileNames.length;
            }
            // only when manual folds are turned off and fewer files exist than folds were
            // requested: re-split the data and continue with the re-split folder
            if (!useCrossValidationManualFolds && fileNames.length < numFolds) {
                xmiPathRoot = createRequestedNumberOfCas(xmiPathRoot, fileNames.length, featureMode);
                // sorted like the initial listing, so the order handed to getFoldDim does
                // not depend on the order of the directory listing
                fileNames = listSortedBinFilePaths(xmiPathRoot);
            }
            // don't change any names!!
            FoldDimensionBundle<String> foldDim = getFoldDim(fileNames);
            Dimension<File> filesRootDim = Dimension.create(DIM_FILES_ROOT, xmiPathRoot);
            ParameterSpace pSpace = new ParameterSpace(foldDim, filesRootDim);
            setParameterSpace(pSpace);
        }

        /**
         * Recursively collects all files with extension {@code bin} below the given folder.
         *
         * @param root
         *            folder to search
         * @return the absolute paths (not just the names) of the files found, sorted
         */
        private String[] listSortedBinFilePaths(File root) {
            Collection<File> files = FileUtils.listFiles(root, new String[] { "bin" }, true);
            String[] paths = new String[files.size()];
            int i = 0;
            for (File f : files) {
                // adding file paths, not names
                paths[i] = f.getAbsolutePath();
                i++;
            }
            Arrays.sort(paths);
            return paths;
        }

        /**
         * Delegates to {@code FoldUtil.createMinimalSplit} to re-split the available data for
         * the requested number of folds and checks that enough files were produced.
         *
         * @param xmiPathRoot
         *            input path
         * @param numAvailableJCas
         *            number of files currently available in the input path
         * @param featureMode
         *            the feature mode; compared against {@code FM_SEQUENCE}
         * @return the folder returned by {@code FoldUtil.createMinimalSplit}
         * @throws IllegalStateException
         *             wrapping any exception raised while splitting or verifying
         */
        private File createRequestedNumberOfCas(File xmiPathRoot, int numAvailableJCas, String featureMode) {
            try {
                File outputFolder = FoldUtil.createMinimalSplit(xmiPathRoot.getAbsolutePath(), numFolds, numAvailableJCas, FM_SEQUENCE.equals(featureMode));
                if (outputFolder == null) {
                    throw new NullPointerException("Output folder is null");
                }
                verifyThatNeededNumberOfCasWasCreated(outputFolder);
                return outputFolder;
            } catch (Exception e) {
                throw new IllegalStateException(e);
            }
        }

        /**
         * Counts the direct children of the folder whose name contains {@code .bin} and
         * fails if there are fewer of them than folds.
         *
         * @param outputFolder
         *            folder to inspect (not searched recursively)
         */
        private void verifyThatNeededNumberOfCasWasCreated(File outputFolder) {
            int numCas = 0;
            File[] listFiles = outputFolder.listFiles();
            if (listFiles == null) {
                throw new NullPointerException("Retrieving files in folder led to a NullPointer");
            }
            for (File f : listFiles) {
                if (f.getName().contains(".bin")) {
                    numCas++;
                }
            }
            if (numCas < numFolds) {
                throw new IllegalStateException("Not enough TextClassificationUnits found to create at least [" + numFolds + "] folds");
            }
        }
    };
    // ================== SUBTASKS OF THE INNER BATCH TASK =======================
    // collecting outcomes; imports the training output of the init task
    collectionTask = new OutcomeCollectionTask();
    collectionTask.setType(collectionTask.getType() + "-" + experimentName);
    collectionTask.setAttribute(TC_TASK_TYPE, TcTaskType.COLLECTION.toString());
    collectionTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN);
    // collecting meta features only on the training data (numFolds times)
    metaTask = new MetaInfoTask();
    metaTask.setOperativeViews(operativeViews);
    metaTask.setType(metaTask.getType() + "-" + experimentName);
    metaTask.setAttribute(TC_TASK_TYPE, TcTaskType.META.toString());
    // extracting features from training data (numFolds times)
    extractFeaturesTrainTask = new ExtractFeaturesTask();
    extractFeaturesTrainTask.setTesting(false);
    extractFeaturesTrainTask.setType(extractFeaturesTrainTask.getType() + "-Train-" + experimentName);
    extractFeaturesTrainTask.addImport(metaTask, MetaInfoTask.META_KEY);
    extractFeaturesTrainTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN, ExtractFeaturesTask.INPUT_KEY);
    extractFeaturesTrainTask.addImport(collectionTask, OutcomeCollectionTask.OUTPUT_KEY, ExtractFeaturesTask.COLLECTION_INPUT_KEY);
    extractFeaturesTrainTask.setAttribute(TC_TASK_TYPE, TcTaskType.FEATURE_EXTRACTION_TRAIN.toString());
    // extracting features from test data (numFolds times)
    extractFeaturesTestTask = new ExtractFeaturesTask();
    extractFeaturesTestTask.setTesting(true);
    extractFeaturesTestTask.setType(extractFeaturesTestTask.getType() + "-Test-" + experimentName);
    extractFeaturesTestTask.addImport(metaTask, MetaInfoTask.META_KEY);
    extractFeaturesTestTask.addImport(extractFeaturesTrainTask, ExtractFeaturesTask.OUTPUT_KEY);
    extractFeaturesTestTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN, ExtractFeaturesTask.INPUT_KEY);
    extractFeaturesTestTask.addImport(collectionTask, OutcomeCollectionTask.OUTPUT_KEY, ExtractFeaturesTask.COLLECTION_INPUT_KEY);
    extractFeaturesTestTask.setAttribute(TC_TASK_TYPE, TcTaskType.FEATURE_EXTRACTION_TEST.toString());
    // test task operating on the models of the feature extraction train and test tasks
    List<ReportBase> reports = new ArrayList<>();
    reports.add(new BasicResultReport());
    testTask = new DKProTcShallowTestTask(extractFeaturesTrainTask, extractFeaturesTestTask, collectionTask, reports, experimentName);
    testTask.setType(testTask.getType() + "-" + experimentName);
    testTask.setAttribute(TC_TASK_TYPE, TcTaskType.FACADE_TASK.toString());
    // user-configured reports are attached in addition to the default one
    if (innerReports != null) {
        for (Class<? extends Report> report : innerReports) {
            testTask.addReport(report);
        }
    }
    testTask.addImport(extractFeaturesTrainTask, ExtractFeaturesTask.OUTPUT_KEY, TEST_TASK_INPUT_KEY_TRAINING_DATA);
    testTask.addImport(extractFeaturesTestTask, ExtractFeaturesTask.OUTPUT_KEY, TEST_TASK_INPUT_KEY_TEST_DATA);
    testTask.addImport(collectionTask, OutcomeCollectionTask.OUTPUT_KEY, Constants.OUTCOMES_INPUT_KEY);
    // ================== CONFIG OF THE INNER BATCH TASK =======================
    crossValidationTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN);
    crossValidationTask.setType(crossValidationTask.getType() + "-" + experimentName);
    crossValidationTask.addTask(collectionTask);
    crossValidationTask.addTask(metaTask);
    crossValidationTask.addTask(extractFeaturesTrainTask);
    crossValidationTask.addTask(extractFeaturesTestTask);
    crossValidationTask.addTask(testTask);
    crossValidationTask.setExecutionPolicy(ExecutionPolicy.USE_EXISTING);
    // report of the inner batch task (sums up results for the folds)
    // we want to re-use the old CV report, we need to collect the evaluation.bin files from
    // the test task here (with another report)
    crossValidationTask.addReport(InnerBatchReport.class);
    crossValidationTask.setAttribute(TC_TASK_TYPE, TcTaskType.CROSS_VALIDATION.toString());
    // DKPro Lab issue 38: must be added as *first* task
    addTask(initTask);
    addTask(crossValidationTask);
}
Also used : ReportBase(org.dkpro.lab.reporting.ReportBase) TaskContext(org.dkpro.lab.engine.TaskContext) ArrayList(java.util.ArrayList) MetaInfoTask(org.dkpro.tc.core.task.MetaInfoTask) TextClassificationException(org.dkpro.tc.api.exception.TextClassificationException) InitTask(org.dkpro.tc.core.task.InitTask) ExtractFeaturesTask(org.dkpro.tc.core.task.ExtractFeaturesTask) DKProTcShallowTestTask(org.dkpro.tc.core.task.DKProTcShallowTestTask) BasicResultReport(org.dkpro.tc.ml.report.BasicResultReport) ParameterSpace(org.dkpro.lab.task.ParameterSpace) OutcomeCollectionTask(org.dkpro.tc.core.task.OutcomeCollectionTask) File(java.io.File) DefaultBatchTask(org.dkpro.lab.task.impl.DefaultBatchTask)

Example 2 with ExtractFeaturesTask

use of org.dkpro.tc.core.task.ExtractFeaturesTask in project dkpro-tc by dkpro.

In class ExperimentSaveModel, method init:

/**
 * Wires up the tasks of the save-model experiment: initialization of the training data,
 * outcome collection, meta-info collection, feature extraction on the training data and the
 * task that serializes the model. This is called automatically before execution rather than
 * from the constructor, so that the experiment can be configured through setters instead of
 * the three-argument constructor.
 *
 * @throws IllegalStateException
 *             if not all necessary arguments have been set, or if creating or configuring
 *             the serialization task fails.
 */
protected void init() {
    if (experimentName == null) {
        throw new IllegalStateException("You must set an experiment name");
    }
    // suffixes appended to each task's default type
    final String nameSuffix = "-" + experimentName;
    final String trainSuffix = "-Train-" + experimentName;

    // init the train part of the experiment
    initTask = new InitTask();
    initTask.setPreprocessing(getPreprocessing());
    initTask.setOperativeViews(operativeViews);
    initTask.setTesting(false);
    initTask.setType(initTask.getType() + trainSuffix);
    initTask.setAttribute(TC_TASK_TYPE, TcTaskType.INIT_TRAIN.toString());

    // outcome collection, fed by the training output of the init task
    collectionTask = new OutcomeCollectionTask();
    collectionTask.setType(collectionTask.getType() + nameSuffix);
    collectionTask.setAttribute(TC_TASK_TYPE, TcTaskType.COLLECTION.toString());
    collectionTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN);

    // meta information, fed by the training output of the init task
    metaTask = new MetaInfoTask();
    metaTask.setOperativeViews(operativeViews);
    metaTask.setType(metaTask.getType() + nameSuffix);
    metaTask.setAttribute(TC_TASK_TYPE, TcTaskType.META.toString());
    metaTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN, MetaInfoTask.INPUT_KEY);

    // feature extraction on training data
    featuresTrainTask = new ExtractFeaturesTask();
    featuresTrainTask.setType(featuresTrainTask.getType() + trainSuffix);
    featuresTrainTask.setAttribute(TC_TASK_TYPE, TcTaskType.FEATURE_EXTRACTION_TRAIN.toString());
    featuresTrainTask.addImport(metaTask, MetaInfoTask.META_KEY);
    featuresTrainTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN, ExtractFeaturesTask.INPUT_KEY);
    featuresTrainTask.addImport(collectionTask, OutcomeCollectionTask.OUTPUT_KEY, ExtractFeaturesTask.COLLECTION_INPUT_KEY);

    // model serialization into the configured output folder; any failure while creating or
    // configuring the task is reported as an IllegalStateException
    try {
        saveModelTask = new DKProTcShallowSerializationTask(metaTask, featuresTrainTask, collectionTask, outputFolder, experimentName);
        saveModelTask.setType(saveModelTask.getType() + nameSuffix);
        saveModelTask.addImport(metaTask, MetaInfoTask.META_KEY);
        saveModelTask.addImport(featuresTrainTask, ExtractFeaturesTask.OUTPUT_KEY, Constants.TEST_TASK_INPUT_KEY_TRAINING_DATA);
        saveModelTask.addImport(collectionTask, OutcomeCollectionTask.OUTPUT_KEY, Constants.OUTCOMES_INPUT_KEY);
        saveModelTask.setAttribute(TC_TASK_TYPE, TcTaskType.FACADE_TASK.toString());
    } catch (Exception e) {
        throw new IllegalStateException(e);
    }

    // DKPro Lab issue 38: must be added as *first* task
    addTask(initTask);
    addTask(collectionTask);
    addTask(metaTask);
    addTask(featuresTrainTask);
    addTask(saveModelTask);
}
Also used : DKProTcShallowSerializationTask(org.dkpro.tc.core.task.DKProTcShallowSerializationTask) OutcomeCollectionTask(org.dkpro.tc.core.task.OutcomeCollectionTask) MetaInfoTask(org.dkpro.tc.core.task.MetaInfoTask) TextClassificationException(org.dkpro.tc.api.exception.TextClassificationException) InitTask(org.dkpro.tc.core.task.InitTask) ExtractFeaturesTask(org.dkpro.tc.core.task.ExtractFeaturesTask)

Example 3 with ExtractFeaturesTask

use of org.dkpro.tc.core.task.ExtractFeaturesTask in project dkpro-tc by dkpro.

In class ExperimentTrainTest, method init:

/**
 * Wires up the tasks of the train-test experiment: one initialization task each for the
 * training and the test data, outcome collection over both, meta-info collection on the
 * training data, feature extraction for train and test, and the final test task. This is
 * called automatically before execution rather than from the constructor, so that the
 * experiment can be configured through setters instead of constructor arguments.
 *
 * @throws IllegalStateException
 *             if no experiment name has been set.
 */
@Override
protected void init() {
    if (experimentName == null) {
        throw new IllegalStateException("You must set an experiment name");
    }
    // suffixes appended to each task's default type
    final String nameSuffix = "-" + experimentName;
    final String trainSuffix = "-Train-" + experimentName;
    final String testSuffix = "-Test-" + experimentName;

    // init the train part of the experiment
    initTaskTrain = new InitTask();
    initTaskTrain.setPreprocessing(getPreprocessing());
    initTaskTrain.setOperativeViews(operativeViews);
    initTaskTrain.setTesting(false);
    initTaskTrain.setType(initTaskTrain.getType() + trainSuffix);
    initTaskTrain.setAttribute(TC_TASK_TYPE, TcTaskType.INIT_TRAIN.toString());

    // init the test part of the experiment
    initTaskTest = new InitTask();
    initTaskTest.setTesting(true);
    initTaskTest.setPreprocessing(getPreprocessing());
    initTaskTest.setOperativeViews(operativeViews);
    initTaskTest.setType(initTaskTest.getType() + testSuffix);
    initTaskTest.setAttribute(TC_TASK_TYPE, TcTaskType.INIT_TEST.toString());

    // outcome collection over the output of both init tasks
    collectionTask = new OutcomeCollectionTask();
    collectionTask.setType(collectionTask.getType() + nameSuffix);
    collectionTask.setAttribute(TC_TASK_TYPE, TcTaskType.COLLECTION.toString());
    collectionTask.addImport(initTaskTrain, InitTask.OUTPUT_KEY_TRAIN);
    collectionTask.addImport(initTaskTest, InitTask.OUTPUT_KEY_TEST);

    // get some meta data depending on the whole document collection that we need for training
    metaTask = new MetaInfoTask();
    metaTask.setOperativeViews(operativeViews);
    metaTask.setType(metaTask.getType() + nameSuffix);
    metaTask.setAttribute(TC_TASK_TYPE, TcTaskType.META.toString());
    metaTask.addImport(initTaskTrain, InitTask.OUTPUT_KEY_TRAIN, MetaInfoTask.INPUT_KEY);

    // feature extraction on training data
    featuresTrainTask = new ExtractFeaturesTask();
    featuresTrainTask.setType(featuresTrainTask.getType() + trainSuffix);
    featuresTrainTask.setTesting(false);
    featuresTrainTask.setAttribute(TC_TASK_TYPE, TcTaskType.FEATURE_EXTRACTION_TRAIN.toString());
    featuresTrainTask.addImport(metaTask, MetaInfoTask.META_KEY);
    featuresTrainTask.addImport(initTaskTrain, InitTask.OUTPUT_KEY_TRAIN, ExtractFeaturesTask.INPUT_KEY);
    featuresTrainTask.addImport(collectionTask, OutcomeCollectionTask.OUTPUT_KEY, ExtractFeaturesTask.COLLECTION_INPUT_KEY);

    // feature extraction on test data; additionally imports the train extraction output
    featuresTestTask = new ExtractFeaturesTask();
    featuresTestTask.setType(featuresTestTask.getType() + testSuffix);
    featuresTestTask.setTesting(true);
    featuresTestTask.setAttribute(TC_TASK_TYPE, TcTaskType.FEATURE_EXTRACTION_TEST.toString());
    featuresTestTask.addImport(metaTask, MetaInfoTask.META_KEY);
    featuresTestTask.addImport(initTaskTest, InitTask.OUTPUT_KEY_TEST, ExtractFeaturesTask.INPUT_KEY);
    featuresTestTask.addImport(featuresTrainTask, ExtractFeaturesTask.OUTPUT_KEY);
    featuresTestTask.addImport(collectionTask, OutcomeCollectionTask.OUTPUT_KEY, ExtractFeaturesTask.COLLECTION_INPUT_KEY);

    // test task operating on the models of the feature extraction train and test tasks;
    // it always gets a BasicResultReport
    List<ReportBase> defaultReports = new ArrayList<>();
    defaultReports.add(new BasicResultReport());
    testTask = new DKProTcShallowTestTask(featuresTrainTask, featuresTestTask, collectionTask, defaultReports, experimentName);
    testTask.setType(testTask.getType() + nameSuffix);
    testTask.setAttribute(TC_TASK_TYPE, TcTaskType.FACADE_TASK.toString());
    // user-configured reports are attached in addition to the default one
    if (innerReports != null) {
        for (Class<? extends Report> reportClass : innerReports) {
            testTask.addReport(reportClass);
        }
    }
    testTask.addImport(featuresTrainTask, ExtractFeaturesTask.OUTPUT_KEY, Constants.TEST_TASK_INPUT_KEY_TRAINING_DATA);
    testTask.addImport(featuresTestTask, ExtractFeaturesTask.OUTPUT_KEY, Constants.TEST_TASK_INPUT_KEY_TEST_DATA);
    testTask.addImport(collectionTask, OutcomeCollectionTask.OUTPUT_KEY, Constants.OUTCOMES_INPUT_KEY);

    // DKPro Lab issue 38: must be added as *first* task
    addTask(initTaskTrain);
    addTask(initTaskTest);
    addTask(collectionTask);
    addTask(metaTask);
    addTask(featuresTrainTask);
    addTask(featuresTestTask);
    addTask(testTask);
}
Also used : DKProTcShallowTestTask(org.dkpro.tc.core.task.DKProTcShallowTestTask) ReportBase(org.dkpro.lab.reporting.ReportBase) BasicResultReport(org.dkpro.tc.ml.report.BasicResultReport) ArrayList(java.util.ArrayList) OutcomeCollectionTask(org.dkpro.tc.core.task.OutcomeCollectionTask) MetaInfoTask(org.dkpro.tc.core.task.MetaInfoTask) InitTask(org.dkpro.tc.core.task.InitTask) ExtractFeaturesTask(org.dkpro.tc.core.task.ExtractFeaturesTask)

Aggregations

ExtractFeaturesTask (org.dkpro.tc.core.task.ExtractFeaturesTask)3 InitTask (org.dkpro.tc.core.task.InitTask)3 MetaInfoTask (org.dkpro.tc.core.task.MetaInfoTask)3 OutcomeCollectionTask (org.dkpro.tc.core.task.OutcomeCollectionTask)3 ArrayList (java.util.ArrayList)2 ReportBase (org.dkpro.lab.reporting.ReportBase)2 TextClassificationException (org.dkpro.tc.api.exception.TextClassificationException)2 DKProTcShallowTestTask (org.dkpro.tc.core.task.DKProTcShallowTestTask)2 BasicResultReport (org.dkpro.tc.ml.report.BasicResultReport)2 File (java.io.File)1 TaskContext (org.dkpro.lab.engine.TaskContext)1 ParameterSpace (org.dkpro.lab.task.ParameterSpace)1 DefaultBatchTask (org.dkpro.lab.task.impl.DefaultBatchTask)1 DKProTcShallowSerializationTask (org.dkpro.tc.core.task.DKProTcShallowSerializationTask)1