Use of org.dkpro.lab.task.impl.DefaultBatchTask in project dkpro-lab by dkpro.
The class MultiThreadBatchTaskTest, method testNested2.
/**
 * Runs a nested batch setup: the enclosing batch iterates over the {@code outer} dimension,
 * while the nested batch builds its own {@code inner} parameter space on demand from the
 * {@code outer} value it has been configured with. Both batches print every configuration
 * they receive.
 */
@Test
public void testNested2() throws Exception {
    DefaultBatchTask nestedBatch = new DefaultBatchTask() {
        // Discriminator whose name matches the "outer" dimension of the enclosing batch.
        @Discriminator
        private Integer outer;

        @Override
        public void setConfiguration(Map<String, Object> aConfig) {
            super.setConfiguration(aConfig);
            System.out.printf("A %10d %s %s%n", this.hashCode(), getType(), aConfig);
        }

        @Override
        public ParameterSpace getParameterSpace() {
            // The "inner" dimension enumerates 0 .. outer-1, so its size follows "outer".
            Integer[] innerValues = new Integer[outer];
            int position = 0;
            while (position < outer) {
                innerValues[position] = position;
                position++;
            }
            Dimension<Integer> dimension = Dimension.create("inner", innerValues);
            return new ParameterSpace(dimension);
        }
    };
    nestedBatch.addTask(new ConfigDumperTask1());

    DefaultBatchTask enclosingBatch = new DefaultBatchTask() {
        @Override
        public void setConfiguration(Map<String, Object> aConfig) {
            super.setConfiguration(aConfig);
            System.out.printf("B %10d %s %s%n", this.hashCode(), getType(), aConfig);
        }
    };
    enclosingBatch.setParameterSpace(new ParameterSpace(Dimension.create("outer", 1, 2, 3)));
    enclosingBatch.addTask(nestedBatch);
    enclosingBatch.addTask(new ConfigDumperTask2());

    Lab.getInstance().run(enclosingBatch);
}
Use of org.dkpro.lab.task.impl.DefaultBatchTask in project dkpro-lab by dkpro.
The class PosExampleCrf, method run.
/**
 * Assembles and runs the CRF part-of-speech example as a single batch of four chained tasks:
 * preprocessing, feature extraction, training and analysis. The steps hand data to each other
 * through the "XMI" and "MODEL" folders.
 */
@Test
public void run() throws Exception {
    // Make UIMA log through log4j
    System.setProperty("org.apache.uima.logger.class", "org.apache.uima.util.impl.Log4jLogger_impl");
    clean();

    // Step 1: read the corpus, stem it and write compressed XMI files into "XMI".
    Task preprocess = new UimaTaskBase() {
        {
            setType("Preprocessing");
        }

        @Discriminator
        String corpusPath;

        @Override
        public CollectionReaderDescription getCollectionReaderDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            CollectionReaderDescription corpusReader = createReader(NegraExportReader.class,
                    NegraExportReader.PARAM_SOURCE_LOCATION, corpusPath,
                    NegraExportReader.PARAM_LANGUAGE, "de");
            return corpusReader;
        }

        @Override
        public AnalysisEngineDescription getAnalysisEngineDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            String xmiTarget = aContext.getFolder("XMI", AccessMode.READWRITE).getAbsolutePath();
            return createEngine(
                    createEngine(SnowballStemmer.class),
                    createEngine(XmiWriter.class,
                            XmiWriter.PARAM_TARGET_LOCATION, xmiTarget,
                            XmiWriter.PARAM_COMPRESSION, CompressionMethod.GZIP));
        }
    };

    // Step 2: read the XMI files back and write CRF training data into "MODEL".
    Task extractFeatures = new UimaTaskBase() {
        {
            setType("FeatureExtraction");
        }

        @Override
        public CollectionReaderDescription getCollectionReaderDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            String xmiSource = aContext.getFolder("XMI", AccessMode.READONLY).getAbsolutePath();
            String[] xmiPatterns = { "[+]**/*.xmi.gz" };
            CollectionReaderDescription xmiReader = createReader(XmiReader.class,
                    XmiReader.PARAM_SOURCE_LOCATION, xmiSource,
                    XmiReader.PARAM_PATTERNS, xmiPatterns);
            return xmiReader;
        }

        @Override
        public AnalysisEngineDescription getAnalysisEngineDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            String modelTarget = aContext.getFolder("MODEL", AccessMode.READWRITE).getAbsolutePath();
            String writerFactory = DefaultMalletCRFDataWriterFactory.class.getName();
            return createEngine(createEngineDescription(ExamplePosAnnotator.class,
                    ExamplePosAnnotator.PARAM_DATA_WRITER_FACTORY_CLASS_NAME, writerFactory,
                    DefaultMalletCRFDataWriterFactory.PARAM_OUTPUT_DIRECTORY, modelTarget));
        }
    };

    // Step 3: train the classifier from the "MODEL" folder and package it there.
    Task train = new ExecutableTaskBase() {
        {
            setType("TrainingTask");
        }

        @Override
        public void execute(TaskContext aContext) throws Exception {
            File modelFolder = aContext.getFolder("MODEL", AccessMode.READWRITE);
            String[] noTrainingArguments = {};
            JarClassifierBuilder<?> builder = JarClassifierBuilder.fromTrainingDirectory(modelFolder);
            builder.trainClassifier(modelFolder, noTrainingArguments);
            builder.packageClassifier(modelFolder);
        }
    };

    // Step 4: tag the test texts with the packaged model and write the result into "TSV".
    Task analyze = new UimaTaskBase() {
        {
            setType("AnalysisTask");
        }

        @Override
        public CollectionReaderDescription getCollectionReaderDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            String[] textPatterns = { "[+]**/*.txt" };
            CollectionReaderDescription textReader = createReaderDescription(TextReader.class,
                    TextReader.PARAM_SOURCE_LOCATION, "src/test/resources/text",
                    TextReader.PARAM_PATTERNS, textPatterns,
                    TextReader.PARAM_LANGUAGE, "de");
            return textReader;
        }

        @Override
        public AnalysisEngineDescription getAnalysisEngineDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            File modelFolder = aContext.getFolder("MODEL", AccessMode.READONLY);
            File tsvFolder = aContext.getFolder("TSV", AccessMode.READWRITE);
            String classifierJar = new File(modelFolder, "model.jar").getAbsolutePath();
            File tsvFile = new File(tsvFolder, "output.tsv");
            return createEngine(
                    createEngineDescription(BreakIteratorSegmenter.class),
                    createEngineDescription(SnowballStemmer.class),
                    createEngineDescription(ExamplePosAnnotator.class,
                            GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH, classifierJar),
                    createEngineDescription(ImsCwbWriter.class,
                            ImsCwbWriter.PARAM_TARGET_LOCATION, tsvFile));
        }
    };

    // Chain the steps through their shared folders.
    extractFeatures.addImport(preprocess, "XMI");
    train.addImport(extractFeatures, "MODEL");
    analyze.addImport(train, "MODEL");

    DefaultBatchTask pipeline = new DefaultBatchTask();
    pipeline.setParameterSpace(new ParameterSpace(Dimension.create("corpusPath", CORPUS_PATH)));
    pipeline.setExecutionPolicy(ExecutionPolicy.USE_EXISTING);
    for (Task step : new Task[] { preprocess, extractFeatures, train, analyze }) {
        pipeline.addTask(step);
    }

    Lab.getInstance().run(pipeline);
}
Use of org.dkpro.lab.task.impl.DefaultBatchTask in project dkpro-lab by dkpro.
The class PosExampleMaxEnt, method run.
/**
 * Assembles and runs the maximum-entropy part-of-speech example as a single batch of four
 * chained tasks: preprocessing, feature extraction, training and analysis. The parameter space
 * combines the corpus path with three values for {@code iterations} and one for {@code cutoff}.
 */
@Test
public void run() throws Exception {
    // Make UIMA log through log4j
    System.setProperty("org.apache.uima.logger.class", "org.apache.uima.util.impl.Log4jLogger_impl");
    clean();

    // Step 1: read the corpus, stem it and write compressed XMI files into "XMI".
    Task preprocess = new UimaTaskBase() {
        {
            setType("Preprocessing");
        }

        @Discriminator
        String corpusPath;

        @Override
        public CollectionReaderDescription getCollectionReaderDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            CollectionReaderDescription corpusReader = createReader(NegraExportReader.class,
                    NegraExportReader.PARAM_SOURCE_LOCATION, corpusPath,
                    NegraExportReader.PARAM_LANGUAGE, "de");
            return corpusReader;
        }

        @Override
        public AnalysisEngineDescription getAnalysisEngineDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            String xmiTarget = aContext.getFolder("XMI", AccessMode.READWRITE).getAbsolutePath();
            return createEngine(
                    createEngine(SnowballStemmer.class),
                    createEngine(XmiWriter.class,
                            XmiWriter.PARAM_TARGET_LOCATION, xmiTarget,
                            XmiWriter.PARAM_COMPRESSION, CompressionMethod.GZIP));
        }
    };

    // Step 2: read the XMI files back and write training data into "MODEL".
    Task extractFeatures = new UimaTaskBase() {
        {
            setType("FeatureExtraction");
        }

        @Override
        public CollectionReaderDescription getCollectionReaderDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            String xmiSource = aContext.getFolder("XMI", AccessMode.READONLY).getAbsolutePath();
            String[] xmiPatterns = { "[+]**/*.xmi.gz" };
            CollectionReaderDescription xmiReader = createReader(XmiReader.class,
                    XmiReader.PARAM_SOURCE_LOCATION, xmiSource,
                    XmiReader.PARAM_PATTERNS, xmiPatterns);
            return xmiReader;
        }

        @Override
        public AnalysisEngineDescription getAnalysisEngineDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            String modelTarget = aContext.getFolder("MODEL", AccessMode.READWRITE).getAbsolutePath();
            String viterbiFactory = ViterbiDataWriterFactory.class.getName();
            String delegateFactory = DefaultMaxentDataWriterFactory.class.getName();
            return createEngine(createEngineDescription(ExamplePosAnnotator.class,
                    ExamplePosAnnotator.PARAM_DATA_WRITER_FACTORY_CLASS_NAME, viterbiFactory,
                    ViterbiDataWriterFactory.PARAM_OUTPUT_DIRECTORY, modelTarget,
                    ViterbiDataWriterFactory.PARAM_DELEGATED_DATA_WRITER_FACTORY_CLASS, delegateFactory));
        }
    };

    // Step 3: train with the current iterations/cutoff combination and package the classifier.
    Task train = new ExecutableTaskBase() {
        {
            setType("TrainingTask");
        }

        @Discriminator
        private int iterations;

        @Discriminator
        private int cutoff;

        @Override
        public void execute(TaskContext aContext) throws Exception {
            File modelFolder = aContext.getFolder("MODEL", AccessMode.READWRITE);
            JarClassifierBuilder<?> builder = JarClassifierBuilder.fromTrainingDirectory(modelFolder);
            String[] trainingArguments = { Integer.toString(iterations), Integer.toString(cutoff) };
            builder.trainClassifier(modelFolder, trainingArguments);
            builder.packageClassifier(modelFolder);
        }
    };

    // Step 4: tag the test texts with the packaged model and write the result into "TSV".
    Task analyze = new UimaTaskBase() {
        {
            setType("AnalysisTask");
        }

        @Override
        public CollectionReaderDescription getCollectionReaderDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            CollectionReaderDescription textReader = createReaderDescription(TextReader.class,
                    TextReader.PARAM_SOURCE_LOCATION, "src/test/resources/text/**/*.txt",
                    TextReader.PARAM_LANGUAGE, "de");
            return textReader;
        }

        @Override
        public AnalysisEngineDescription getAnalysisEngineDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            File modelFolder = aContext.getFolder("MODEL", AccessMode.READONLY);
            File tsvFolder = aContext.getFolder("TSV", AccessMode.READWRITE);
            String classifierJar = new File(modelFolder, "model.jar").getAbsolutePath();
            File tsvFile = new File(tsvFolder, "output.tsv");
            return createEngine(
                    createEngineDescription(BreakIteratorSegmenter.class),
                    createEngineDescription(SnowballStemmer.class),
                    createEngineDescription(ExamplePosAnnotator.class,
                            GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH, classifierJar),
                    createEngineDescription(ImsCwbWriter.class,
                            ImsCwbWriter.PARAM_TARGET_LOCATION, tsvFile));
        }
    };

    // Chain the steps through their shared folders.
    extractFeatures.addImport(preprocess, "XMI");
    train.addImport(extractFeatures, "MODEL");
    analyze.addImport(train, "MODEL");

    DefaultBatchTask pipeline = new DefaultBatchTask();
    pipeline.setParameterSpace(new ParameterSpace(
            Dimension.create("corpusPath", CORPUS_PATH),
            Dimension.create("iterations", 20, 50, 100),
            Dimension.create("cutoff", 5)));
    pipeline.setExecutionPolicy(ExecutionPolicy.USE_EXISTING);
    for (Task step : new Task[] { preprocess, extractFeatures, train, analyze }) {
        pipeline.addTask(step);
    }

    Lab.getInstance().run(pipeline);
}
Use of org.dkpro.lab.task.impl.DefaultBatchTask in project dkpro-tc by dkpro.
The class ExperimentCrossValidation, method init.
/**
 * Initializes the experiment by wiring up its task graph. This is called automatically before
 * execution. It's not done directly in the constructor, because we want to be able to use
 * setters instead of the three-argument constructor.
 * <p>
 * The method creates an init task and a nested cross-validation batch task and adds both to
 * this experiment. The nested batch contains the outcome-collection, meta-info, train/test
 * feature-extraction and test tasks; its parameter space (the folds) is computed lazily in
 * {@code initialize(TaskContext)} from the files produced by the init task.
 *
 * @throws IllegalStateException
 *             if no experiment name is set or fewer than two folds are configured
 */
protected void init() throws IllegalStateException {
if (experimentName == null) {
throw new IllegalStateException("You must set an experiment name");
}
// NOTE(review): this check runs before LEAVE_ONE_OUT is resolved to a concrete fold count in
// initialize() below - confirm that the LEAVE_ONE_OUT constant itself passes it.
if (numFolds < 2) {
throw new IllegalStateException("Number of folds is not configured correctly. Number of folds needs to be at " + "least 2 (but was " + numFolds + ")");
}
// initialize the setup: the init task gets the preprocessing, the operative views, a type
// suffixed with the experiment name and its TC task-type attribute
initTask = new InitTask();
initTask.setPreprocessing(getPreprocessing());
initTask.setOperativeViews(operativeViews);
initTask.setType(initTask.getType() + "-" + experimentName);
initTask.setAttribute(TC_TASK_TYPE, TcTaskType.INIT_TRAIN.toString());
// inner batch task (carried out numFolds times)
DefaultBatchTask crossValidationTask = new DefaultBatchTask() {
@Discriminator(name = DIM_FEATURE_MODE)
private String featureMode;
@Discriminator(name = DIM_CROSS_VALIDATION_MANUAL_FOLDS)
private boolean useCrossValidationManualFolds;
/**
 * Builds the parameter space of this batch from the "bin" files found in the init task's
 * training output folder: a fold dimension over the file paths plus a dimension holding
 * the root folder of those files.
 */
@Override
public void initialize(TaskContext aContext) {
super.initialize(aContext);
File xmiPathRoot = aContext.getFolder(InitTask.OUTPUT_KEY_TRAIN, AccessMode.READONLY);
Collection<File> files = FileUtils.listFiles(xmiPathRoot, new String[] { "bin" }, true);
String[] fileNames = new String[files.size()];
int i = 0;
for (File f : files) {
// adding file paths, not names
fileNames[i] = f.getAbsolutePath();
i++;
}
Arrays.sort(fileNames);
// leave-one-out: use as many folds as there are files (this overwrites the numFolds field of
// the enclosing experiment)
if (numFolds == LEAVE_ONE_OUT) {
numFolds = fileNames.length;
}
// manual fold mode is off and there are fewer files than folds: create a new split with
// enough files and continue with the files of that new folder
if (!useCrossValidationManualFolds && fileNames.length < numFolds) {
xmiPathRoot = createRequestedNumberOfCas(xmiPathRoot, fileNames.length, featureMode);
files = FileUtils.listFiles(xmiPathRoot, new String[] { "bin" }, true);
fileNames = new String[files.size()];
i = 0;
for (File f : files) {
// adding file paths, not names
fileNames[i] = f.getAbsolutePath();
i++;
}
// NOTE(review): unlike the first listing above, these paths are not sorted again -
// confirm whether the fold assignment is meant to depend on the listing order here.
}
// don't change any names!!
FoldDimensionBundle<String> foldDim = getFoldDim(fileNames);
Dimension<File> filesRootDim = Dimension.create(DIM_FILES_ROOT, xmiPathRoot);
ParameterSpace pSpace = new ParameterSpace(foldDim, filesRootDim);
setParameterSpace(pSpace);
}
/**
 * Creates a new split of the input via {@code FoldUtil.createMinimalSplit} so that enough
 * CAS files for the configured number of folds exist, and verifies the result.
 *
 * @param xmiPathRoot
 *            input path
 * @param numAvailableJCas
 *            number of CAS files currently available
 * @param featureMode
 *            the feature mode; compared against {@code FM_SEQUENCE} for the split
 * @return the folder containing the newly created split
 * @throws IllegalStateException
 *             wrapping any exception raised while splitting or verifying
 */
private File createRequestedNumberOfCas(File xmiPathRoot, int numAvailableJCas, String featureMode) {
try {
File outputFolder = FoldUtil.createMinimalSplit(xmiPathRoot.getAbsolutePath(), numFolds, numAvailableJCas, FM_SEQUENCE.equals(featureMode));
if (outputFolder == null) {
throw new NullPointerException("Output folder is null");
}
verfiyThatNeededNumberOfCasWasCreated(outputFolder);
return outputFolder;
} catch (Exception e) {
// every failure, including the NullPointerException above, surfaces as IllegalStateException
throw new IllegalStateException(e);
}
}
/**
 * Counts the direct children of the folder whose name contains ".bin" and fails if there
 * are fewer of them than configured folds.
 *
 * @param outputFolder
 *            folder to inspect
 * @throws NullPointerException
 *             if the folder content cannot be listed
 * @throws IllegalStateException
 *             if fewer matching files than {@code numFolds} are found
 */
private void verfiyThatNeededNumberOfCasWasCreated(File outputFolder) {
int numCas = 0;
File[] listFiles = outputFolder.listFiles();
if (listFiles == null) {
throw new NullPointerException("Retrieving files in folder led to a NullPointer");
}
for (File f : listFiles) {
if (f.getName().contains(".bin")) {
numCas++;
}
}
if (numCas < numFolds) {
throw new IllegalStateException("Not enough TextClassificationUnits found to create at least [" + numFolds + "] folds");
}
}
};
// ================== SUBTASKS OF THE INNER BATCH TASK =======================
// collecting meta features only on the training data (numFolds times)
collectionTask = new OutcomeCollectionTask();
collectionTask.setType(collectionTask.getType() + "-" + experimentName);
collectionTask.setAttribute(TC_TASK_TYPE, TcTaskType.COLLECTION.toString());
collectionTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN);
// meta-info task; it declares no imports in this method
metaTask = new MetaInfoTask();
metaTask.setOperativeViews(operativeViews);
metaTask.setType(metaTask.getType() + "-" + experimentName);
metaTask.setAttribute(TC_TASK_TYPE, TcTaskType.META.toString());
// extracting features from training data (numFolds times)
extractFeaturesTrainTask = new ExtractFeaturesTask();
extractFeaturesTrainTask.setTesting(false);
extractFeaturesTrainTask.setType(extractFeaturesTrainTask.getType() + "-Train-" + experimentName);
extractFeaturesTrainTask.addImport(metaTask, MetaInfoTask.META_KEY);
extractFeaturesTrainTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN, ExtractFeaturesTask.INPUT_KEY);
extractFeaturesTrainTask.addImport(collectionTask, OutcomeCollectionTask.OUTPUT_KEY, ExtractFeaturesTask.COLLECTION_INPUT_KEY);
extractFeaturesTrainTask.setAttribute(TC_TASK_TYPE, TcTaskType.FEATURE_EXTRACTION_TRAIN.toString());
// extracting features from test data (numFolds times); besides the imports of the train
// variant it also imports the output of the train feature extraction
extractFeaturesTestTask = new ExtractFeaturesTask();
extractFeaturesTestTask.setTesting(true);
extractFeaturesTestTask.setType(extractFeaturesTestTask.getType() + "-Test-" + experimentName);
extractFeaturesTestTask.addImport(metaTask, MetaInfoTask.META_KEY);
extractFeaturesTestTask.addImport(extractFeaturesTrainTask, ExtractFeaturesTask.OUTPUT_KEY);
extractFeaturesTestTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN, ExtractFeaturesTask.INPUT_KEY);
extractFeaturesTestTask.addImport(collectionTask, OutcomeCollectionTask.OUTPUT_KEY, ExtractFeaturesTask.COLLECTION_INPUT_KEY);
extractFeaturesTestTask.setAttribute(TC_TASK_TYPE, TcTaskType.FEATURE_EXTRACTION_TEST.toString());
// test task operating on the models of the feature extraction train and test tasks
List<ReportBase> reports = new ArrayList<>();
reports.add(new BasicResultReport());
testTask = new DKProTcShallowTestTask(extractFeaturesTrainTask, extractFeaturesTestTask, collectionTask, reports, experimentName);
testTask.setType(testTask.getType() + "-" + experimentName);
testTask.setAttribute(TC_TASK_TYPE, TcTaskType.FACADE_TASK.toString());
// user-supplied inner reports are optional
if (innerReports != null) {
for (Class<? extends Report> report : innerReports) {
testTask.addReport(report);
}
}
testTask.addImport(extractFeaturesTrainTask, ExtractFeaturesTask.OUTPUT_KEY, TEST_TASK_INPUT_KEY_TRAINING_DATA);
testTask.addImport(extractFeaturesTestTask, ExtractFeaturesTask.OUTPUT_KEY, TEST_TASK_INPUT_KEY_TEST_DATA);
testTask.addImport(collectionTask, OutcomeCollectionTask.OUTPUT_KEY, Constants.OUTCOMES_INPUT_KEY);
// ================== CONFIG OF THE INNER BATCH TASK =======================
crossValidationTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN);
crossValidationTask.setType(crossValidationTask.getType() + "-" + experimentName);
crossValidationTask.addTask(collectionTask);
crossValidationTask.addTask(metaTask);
crossValidationTask.addTask(extractFeaturesTrainTask);
crossValidationTask.addTask(extractFeaturesTestTask);
crossValidationTask.addTask(testTask);
crossValidationTask.setExecutionPolicy(ExecutionPolicy.USE_EXISTING);
// report of the inner batch task (sums up results for the folds)
// we want to re-use the old CV report, we need to collect the evaluation.bin files from
// the test task here (with another report)
crossValidationTask.addReport(InnerBatchReport.class);
crossValidationTask.setAttribute(TC_TASK_TYPE, TcTaskType.CROSS_VALIDATION.toString());
// DKPro Lab issue 38: must be added as *first* task
addTask(initTask);
addTask(crossValidationTask);
}
Use of org.dkpro.lab.task.impl.DefaultBatchTask in project dkpro-tc by dkpro.
The class DeepLearningExperimentCrossValidation, method init.
/**
 * Initializes the experiment by wiring up its task graph. This is called automatically before
 * execution. It's not done directly in the constructor, because we want to be able to use
 * setters instead of the three-argument constructor.
 * <p>
 * The method creates an init task and a nested cross-validation batch task and adds both to
 * this experiment. The nested batch contains the preparation, embedding, train/test
 * vectorization and learning tasks; its parameter space (the folds) is computed lazily in
 * {@code initialize(TaskContext)} from the files produced by the init task.
 *
 * @throws IllegalStateException
 *             if no experiment name is set or fewer than two folds are configured
 */
protected void init() throws IllegalStateException {
    if (experimentName == null) {
        throw new IllegalStateException("You must set an experiment name");
    }
    if (numFolds < 2) {
        throw new IllegalStateException("Number of folds is not configured correctly. Number of folds needs to be at " + "least 2 (but was " + numFolds + ")");
    }

    // initialize the setup
    initTask = new InitTaskDeep();
    initTask.setPreprocessing(getPreprocessing());
    initTask.setOperativeViews(operativeViews);
    initTask.setType(initTask.getType() + "-" + experimentName);
    initTask.setAttribute(TC_TASK_TYPE, TcTaskType.INIT_TRAIN.toString());

    // inner batch task (carried out numFolds times)
    DefaultBatchTask crossValidationTask = new DefaultBatchTask() {
        @Discriminator(name = DIM_FEATURE_MODE)
        private String featureMode;

        @Discriminator(name = DIM_CROSS_VALIDATION_MANUAL_FOLDS)
        private boolean useCrossValidationManualFolds;

        /**
         * Builds the parameter space of this batch from the "bin" files found in the init
         * task's training output folder: a fold dimension over the file paths plus a
         * dimension holding the root folder of those files.
         */
        @Override
        public void initialize(TaskContext aContext) {
            super.initialize(aContext);
            File xmiPathRoot = aContext.getFolder(InitTask.OUTPUT_KEY_TRAIN, AccessMode.READONLY);
            Collection<File> files = FileUtils.listFiles(xmiPathRoot, new String[] { "bin" }, true);
            String[] fileNames = new String[files.size()];
            int i = 0;
            for (File f : files) {
                // adding file paths, not names
                fileNames[i] = f.getAbsolutePath();
                i++;
            }
            Arrays.sort(fileNames);

            // leave-one-out: use as many folds as there are files (this overwrites the
            // numFolds field of the enclosing experiment)
            if (numFolds == LEAVE_ONE_OUT) {
                numFolds = fileNames.length;
            }

            // manual mode is turned off and there are fewer files than folds: create a new
            // split with enough files and continue with the files of that new folder
            if (!useCrossValidationManualFolds && fileNames.length < numFolds) {
                xmiPathRoot = createRequestedNumberOfCas(xmiPathRoot, fileNames.length, featureMode);
                files = FileUtils.listFiles(xmiPathRoot, new String[] { "bin" }, true);
                fileNames = new String[files.size()];
                i = 0;
                for (File f : files) {
                    // adding file paths, not names
                    fileNames[i] = f.getAbsolutePath();
                    i++;
                }
            }

            // don't change any names!!
            FoldDimensionBundle<String> foldDim = getFoldDim(fileNames);
            Dimension<File> filesRootDim = Dimension.create(DIM_FILES_ROOT, xmiPathRoot);
            ParameterSpace pSpace = new ParameterSpace(foldDim, filesRootDim);
            setParameterSpace(pSpace);
        }

        /**
         * Creates a new split of the input via {@code FoldUtil.createMinimalSplit} so that
         * enough CAS files for the configured number of folds exist, and verifies the result.
         *
         * @param xmiPathRoot
         *            input path
         * @param numAvailableJCas
         *            number of CAS files currently available
         * @param featureMode
         *            the feature mode; compared against {@code FM_SEQUENCE} for the split
         * @return the folder containing the newly created split
         * @throws IllegalStateException
         *             wrapping any exception raised while splitting or verifying
         */
        private File createRequestedNumberOfCas(File xmiPathRoot, int numAvailableJCas, String featureMode) {
            try {
                File outputFolder = FoldUtil.createMinimalSplit(xmiPathRoot.getAbsolutePath(), numFolds, numAvailableJCas, FM_SEQUENCE.equals(featureMode));
                if (outputFolder == null) {
                    throw new NullPointerException("Output folder is null");
                }
                verfiyThatNeededNumberOfCasWasCreated(outputFolder);
                return outputFolder;
            } catch (Exception e) {
                // every failure, including the NullPointerException above, surfaces as
                // IllegalStateException
                throw new IllegalStateException(e);
            }
        }

        /**
         * Counts the direct children of the folder whose name contains ".bin" and fails if
         * there are fewer of them than configured folds.
         *
         * @param outputFolder
         *            folder to inspect
         * @throws NullPointerException
         *             if the folder content cannot be listed
         * @throws IllegalStateException
         *             if fewer matching files than {@code numFolds} are found
         */
        private void verfiyThatNeededNumberOfCasWasCreated(File outputFolder) {
            int numCas = 0;
            File[] listFiles = outputFolder.listFiles();
            if (listFiles == null) {
                throw new NullPointerException("Retrieving files in folder led to a NullPointer");
            }
            for (File f : listFiles) {
                if (f.getName().contains(".bin")) {
                    numCas++;
                }
            }
            if (numCas < numFolds) {
                throw new IllegalStateException("Not enough TextClassificationUnits found to create at least [" + numFolds + "] folds");
            }
        }
    };

    // ================== SUBTASKS OF THE INNER BATCH TASK =======================
    // get some meta data depending on the whole document collection
    preparationTask = new PreparationTask();
    preparationTask.setType(preparationTask.getType() + "-" + experimentName);
    preparationTask.setMachineLearningAdapter(mlAdapter);
    preparationTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN, PreparationTask.INPUT_KEY_TRAIN);
    preparationTask.setAttribute(TC_TASK_TYPE, TcTaskType.PREPARATION.toString());

    // embedding task, fed by the output of the preparation task
    embeddingTask = new EmbeddingTask();
    embeddingTask.setType(embeddingTask.getType() + "-" + experimentName);
    embeddingTask.addImport(preparationTask, PreparationTask.OUTPUT_KEY, EmbeddingTask.INPUT_MAPPING);
    embeddingTask.setAttribute(TC_TASK_TYPE, TcTaskType.EMBEDDING.toString());

    // feature extraction on training data
    vectorizationTrainTask = new VectorizationTask();
    vectorizationTrainTask.setType(vectorizationTrainTask.getType() + "-Train-" + experimentName);
    vectorizationTrainTask.setTesting(false);
    vectorizationTrainTask.addImport(preparationTask, PreparationTask.OUTPUT_KEY, VectorizationTask.MAPPING_INPUT_KEY);
    vectorizationTrainTask.setAttribute(TC_TASK_TYPE, TcTaskType.VECTORIZATION_TRAIN.toString());

    // feature extraction on test data
    vectorizationTestTask = new VectorizationTask();
    vectorizationTestTask.setType(vectorizationTestTask.getType() + "-Test-" + experimentName);
    vectorizationTestTask.setTesting(true);
    vectorizationTestTask.addImport(preparationTask, PreparationTask.OUTPUT_KEY, VectorizationTask.MAPPING_INPUT_KEY);
    // The VECTORIZATION_TEST tag belongs on the *test* task; setting it on the train task would
    // overwrite that task's VECTORIZATION_TRAIN tag and leave the test task untagged.
    vectorizationTestTask.setAttribute(TC_TASK_TYPE, TcTaskType.VECTORIZATION_TEST.toString());

    // learning task operating on the outputs of the train and test vectorization tasks
    learningTask = mlAdapter.getTestTask();
    learningTask.setType(learningTask.getType() + "-" + experimentName);
    learningTask.setAttribute(TC_TASK_TYPE, TcTaskType.MACHINE_LEARNING_ADAPTER.toString());

    // user-supplied inner reports are optional
    if (innerReports != null) {
        for (Class<? extends Report> report : innerReports) {
            learningTask.addReport(report);
        }
    }

    // reports that are always added, regardless of the user-supplied ones
    learningTask.addReport(mlAdapter.getOutcomeIdReportClass());
    learningTask.addReport(mlAdapter.getMajorityBaselineIdReportClass());
    learningTask.addReport(mlAdapter.getRandomBaselineIdReportClass());
    learningTask.addReport(mlAdapter.getMetaCollectionReport());
    learningTask.addReport(BasicResultReport.class);

    // the vectorization outputs are imported under several keys
    learningTask.addImport(preparationTask, PreparationTask.OUTPUT_KEY, TcDeepLearningAdapter.PREPARATION_FOLDER);
    learningTask.addImport(vectorizationTrainTask, VectorizationTask.OUTPUT_KEY, Constants.TEST_TASK_INPUT_KEY_TRAINING_DATA);
    learningTask.addImport(vectorizationTestTask, VectorizationTask.OUTPUT_KEY, Constants.TEST_TASK_INPUT_KEY_TEST_DATA);
    learningTask.addImport(embeddingTask, EmbeddingTask.OUTPUT_KEY, TcDeepLearningAdapter.EMBEDDING_FOLDER);
    learningTask.addImport(vectorizationTrainTask, VectorizationTask.OUTPUT_KEY, TcDeepLearningAdapter.VECTORIZIATION_TRAIN_OUTPUT);
    learningTask.addImport(vectorizationTrainTask, VectorizationTask.OUTPUT_KEY, TcDeepLearningAdapter.TARGET_ID_MAPPING_TRAIN);
    learningTask.addImport(vectorizationTestTask, VectorizationTask.OUTPUT_KEY, TcDeepLearningAdapter.VECTORIZIATION_TEST_OUTPUT);
    learningTask.addImport(vectorizationTestTask, VectorizationTask.OUTPUT_KEY, TcDeepLearningAdapter.TARGET_ID_MAPPING_TEST);

    // ================== CONFIG OF THE INNER BATCH TASK =======================
    crossValidationTask.addImport(initTask, InitTask.OUTPUT_KEY_TRAIN);
    crossValidationTask.setType(crossValidationTask.getType() + "-" + experimentName);
    crossValidationTask.addTask(preparationTask);
    crossValidationTask.addTask(embeddingTask);
    crossValidationTask.addTask(vectorizationTrainTask);
    crossValidationTask.addTask(vectorizationTestTask);
    crossValidationTask.addTask(learningTask);
    crossValidationTask.setExecutionPolicy(ExecutionPolicy.USE_EXISTING);

    // report of the inner batch task (sums up results for the folds)
    // we want to re-use the old CV report, we need to collect the evaluation.bin files from
    // the test task here (with another report)
    crossValidationTask.addReport(DeepLearningInnerBatchReport.class);
    crossValidationTask.setAttribute(TC_TASK_TYPE, TcTaskType.CROSS_VALIDATION.toString());

    // DKPro Lab issue 38: must be added as *first* task
    addTask(initTask);
    addTask(crossValidationTask);
}
Aggregations