Use of org.dkpro.lab.task.ParameterSpace in project dkpro-lab by dkpro.
The class PosExampleCrf, method run.
/**
 * Runs the CRF part-of-speech example as one batch of four chained tasks:
 * preprocessing, feature extraction, training and analysis.
 *
 * @throws Exception if setting up or running the batch fails
 */
@Test
public void run() throws Exception {
    // Route logging through log4j
    System.setProperty("org.apache.uima.logger.class", "org.apache.uima.util.impl.Log4jLogger_impl");
    clean();

    // Task 1: read the corpus given by the "corpusPath" discriminator and
    // write gzip-compressed XMI into the task's XMI folder.
    Task preprocessingTask = new UimaTaskBase() {
        @Discriminator
        String corpusPath;

        {
            setType("Preprocessing");
        }

        @Override
        public CollectionReaderDescription getCollectionReaderDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            return createReader(
                    NegraExportReader.class,
                    NegraExportReader.PARAM_SOURCE_LOCATION, corpusPath,
                    NegraExportReader.PARAM_LANGUAGE, "de");
        }

        @Override
        public AnalysisEngineDescription getAnalysisEngineDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            File xmiFolder = aContext.getFolder("XMI", AccessMode.READWRITE);
            return createEngine(
                    createEngine(SnowballStemmer.class),
                    createEngine(
                            XmiWriter.class,
                            XmiWriter.PARAM_TARGET_LOCATION, xmiFolder.getAbsolutePath(),
                            XmiWriter.PARAM_COMPRESSION, CompressionMethod.GZIP));
        }
    };

    // Task 2: read the XMI files back and run the annotator with a data
    // writer factory whose output directory is the MODEL folder.
    Task featureExtractionTask = new UimaTaskBase() {
        {
            setType("FeatureExtraction");
        }

        @Override
        public CollectionReaderDescription getCollectionReaderDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            File xmiFolder = aContext.getFolder("XMI", AccessMode.READONLY);
            return createReader(
                    XmiReader.class,
                    XmiReader.PARAM_SOURCE_LOCATION, xmiFolder.getAbsolutePath(),
                    XmiReader.PARAM_PATTERNS, new String[] { "[+]**/*.xmi.gz" });
        }

        @Override
        public AnalysisEngineDescription getAnalysisEngineDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            File modelFolder = aContext.getFolder("MODEL", AccessMode.READWRITE);
            return createEngine(
                    createEngineDescription(
                            ExamplePosAnnotator.class,
                            ExamplePosAnnotator.PARAM_DATA_WRITER_FACTORY_CLASS_NAME,
                            DefaultMalletCRFDataWriterFactory.class.getName(),
                            DefaultMalletCRFDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
                            modelFolder.getAbsolutePath()));
        }
    };

    // Task 3: train a classifier from the MODEL folder and package it there.
    Task trainingTask = new ExecutableTaskBase() {
        {
            setType("TrainingTask");
        }

        @Override
        public void execute(TaskContext aContext) throws Exception {
            File trainingDir = aContext.getFolder("MODEL", AccessMode.READWRITE);
            JarClassifierBuilder<?> builder = JarClassifierBuilder.fromTrainingDirectory(trainingDir);
            builder.trainClassifier(trainingDir, new String[0]);
            builder.packageClassifier(trainingDir);
        }
    };

    // Task 4: apply the packaged model to the test texts and write the
    // result to output.tsv in the TSV folder.
    Task analysisTask = new UimaTaskBase() {
        {
            setType("AnalysisTask");
        }

        @Override
        public CollectionReaderDescription getCollectionReaderDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            return createReaderDescription(
                    TextReader.class,
                    TextReader.PARAM_SOURCE_LOCATION, "src/test/resources/text",
                    TextReader.PARAM_PATTERNS, new String[] { "[+]**/*.txt" },
                    TextReader.PARAM_LANGUAGE, "de");
        }

        @Override
        public AnalysisEngineDescription getAnalysisEngineDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            File modelJar = new File(aContext.getFolder("MODEL", AccessMode.READONLY), "model.jar");
            File tsvFile = new File(aContext.getFolder("TSV", AccessMode.READWRITE), "output.tsv");
            return createEngine(
                    createEngineDescription(BreakIteratorSegmenter.class),
                    createEngineDescription(SnowballStemmer.class),
                    createEngineDescription(
                            ExamplePosAnnotator.class,
                            GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
                            modelJar.getAbsolutePath()),
                    createEngineDescription(
                            ImsCwbWriter.class,
                            ImsCwbWriter.PARAM_TARGET_LOCATION, tsvFile));
        }
    };

    // Wire each task to the folder produced by its predecessor.
    featureExtractionTask.addImport(preprocessingTask, "XMI");
    trainingTask.addImport(featureExtractionTask, "MODEL");
    analysisTask.addImport(trainingTask, "MODEL");

    // Only one dimension here: the corpus location.
    ParameterSpace parameterSpace = new ParameterSpace(Dimension.create("corpusPath", CORPUS_PATH));

    DefaultBatchTask pipeline = new DefaultBatchTask();
    pipeline.setParameterSpace(parameterSpace);
    pipeline.setExecutionPolicy(ExecutionPolicy.USE_EXISTING);
    pipeline.addTask(preprocessingTask);
    pipeline.addTask(featureExtractionTask);
    pipeline.addTask(trainingTask);
    pipeline.addTask(analysisTask);

    Lab.getInstance().run(pipeline);
}
Use of org.dkpro.lab.task.ParameterSpace in project dkpro-lab by dkpro.
The class PosExampleMaxEnt, method run.
/**
 * Runs the MaxEnt part-of-speech example as one batch of four chained tasks:
 * preprocessing, feature extraction, training and analysis. The parameter
 * space varies the training "iterations" value (20, 50, 100) with a single
 * "cutoff" value (5).
 *
 * @throws Exception if setting up or running the batch fails
 */
@Test
public void run() throws Exception {
    // Route logging through log4j
    System.setProperty("org.apache.uima.logger.class", "org.apache.uima.util.impl.Log4jLogger_impl");
    clean();

    // Task 1: read the corpus given by the "corpusPath" discriminator and
    // write gzip-compressed XMI into the task's XMI folder.
    Task preprocessingTask = new UimaTaskBase() {
        @Discriminator
        String corpusPath;

        {
            setType("Preprocessing");
        }

        @Override
        public CollectionReaderDescription getCollectionReaderDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            return createReader(
                    NegraExportReader.class,
                    NegraExportReader.PARAM_SOURCE_LOCATION, corpusPath,
                    NegraExportReader.PARAM_LANGUAGE, "de");
        }

        @Override
        public AnalysisEngineDescription getAnalysisEngineDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            File xmiFolder = aContext.getFolder("XMI", AccessMode.READWRITE);
            return createEngine(
                    createEngine(SnowballStemmer.class),
                    createEngine(
                            XmiWriter.class,
                            XmiWriter.PARAM_TARGET_LOCATION, xmiFolder.getAbsolutePath(),
                            XmiWriter.PARAM_COMPRESSION, CompressionMethod.GZIP));
        }
    };

    // Task 2: read the XMI files back and run the annotator with a Viterbi
    // data writer factory that delegates to the MaxEnt data writer factory
    // and outputs into the MODEL folder.
    Task featureExtractionTask = new UimaTaskBase() {
        {
            setType("FeatureExtraction");
        }

        @Override
        public CollectionReaderDescription getCollectionReaderDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            File xmiFolder = aContext.getFolder("XMI", AccessMode.READONLY);
            return createReader(
                    XmiReader.class,
                    XmiReader.PARAM_SOURCE_LOCATION, xmiFolder.getAbsolutePath(),
                    XmiReader.PARAM_PATTERNS, new String[] { "[+]**/*.xmi.gz" });
        }

        @Override
        public AnalysisEngineDescription getAnalysisEngineDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            File modelFolder = aContext.getFolder("MODEL", AccessMode.READWRITE);
            return createEngine(
                    createEngineDescription(
                            ExamplePosAnnotator.class,
                            ExamplePosAnnotator.PARAM_DATA_WRITER_FACTORY_CLASS_NAME,
                            ViterbiDataWriterFactory.class.getName(),
                            ViterbiDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
                            modelFolder.getAbsolutePath(),
                            ViterbiDataWriterFactory.PARAM_DELEGATED_DATA_WRITER_FACTORY_CLASS,
                            DefaultMaxentDataWriterFactory.class.getName()));
        }
    };

    // Task 3: train a classifier from the MODEL folder, passing the
    // "iterations" and "cutoff" discriminators as training arguments,
    // then package it in the same folder.
    Task trainingTask = new ExecutableTaskBase() {
        @Discriminator
        private int iterations;

        @Discriminator
        private int cutoff;

        {
            setType("TrainingTask");
        }

        @Override
        public void execute(TaskContext aContext) throws Exception {
            File trainingDir = aContext.getFolder("MODEL", AccessMode.READWRITE);
            JarClassifierBuilder<?> builder = JarClassifierBuilder.fromTrainingDirectory(trainingDir);
            String[] trainingArgs = new String[] { String.valueOf(iterations), String.valueOf(cutoff) };
            builder.trainClassifier(trainingDir, trainingArgs);
            builder.packageClassifier(trainingDir);
        }
    };

    // Task 4: apply the packaged model to the test texts and write the
    // result to output.tsv in the TSV folder.
    Task analysisTask = new UimaTaskBase() {
        {
            setType("AnalysisTask");
        }

        @Override
        public CollectionReaderDescription getCollectionReaderDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            return createReaderDescription(
                    TextReader.class,
                    TextReader.PARAM_SOURCE_LOCATION, "src/test/resources/text/**/*.txt",
                    TextReader.PARAM_LANGUAGE, "de");
        }

        @Override
        public AnalysisEngineDescription getAnalysisEngineDescription(TaskContext aContext)
                throws ResourceInitializationException, IOException {
            File modelJar = new File(aContext.getFolder("MODEL", AccessMode.READONLY), "model.jar");
            File tsvFile = new File(aContext.getFolder("TSV", AccessMode.READWRITE), "output.tsv");
            return createEngine(
                    createEngineDescription(BreakIteratorSegmenter.class),
                    createEngineDescription(SnowballStemmer.class),
                    createEngineDescription(
                            ExamplePosAnnotator.class,
                            GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
                            modelJar.getAbsolutePath()),
                    createEngineDescription(
                            ImsCwbWriter.class,
                            ImsCwbWriter.PARAM_TARGET_LOCATION, tsvFile));
        }
    };

    // Wire each task to the folder produced by its predecessor.
    featureExtractionTask.addImport(preprocessingTask, "XMI");
    trainingTask.addImport(featureExtractionTask, "MODEL");
    analysisTask.addImport(trainingTask, "MODEL");

    // Dimension names must match the @Discriminator field names above.
    ParameterSpace parameterSpace = new ParameterSpace(
            Dimension.create("corpusPath", CORPUS_PATH),
            Dimension.create("iterations", 20, 50, 100),
            Dimension.create("cutoff", 5));

    DefaultBatchTask pipeline = new DefaultBatchTask();
    pipeline.setParameterSpace(parameterSpace);
    pipeline.setExecutionPolicy(ExecutionPolicy.USE_EXISTING);
    pipeline.addTask(preprocessingTask);
    pipeline.addTask(featureExtractionTask);
    pipeline.addTask(trainingTask);
    pipeline.addTask(analysisTask);

    Lab.getInstance().run(pipeline);
}
Use of org.dkpro.lab.task.ParameterSpace in project dkpro-tc by dkpro.
The class DeepLearningDl4jSeq2SeqTrainTest, method getParameterSpace.
/**
 * Builds the parameter space for the train/test experiment: a bundle with the
 * training and test readers, plus single-valued dimensions for feature mode,
 * learning mode, pretrained embeddings, vectorization, vocabulary restriction
 * and the user code.
 *
 * @return the assembled parameter space
 * @throws ResourceInitializationException if a reader description cannot be created
 */
public static ParameterSpace getParameterSpace() throws ResourceInitializationException {
    // Both readers share the same configuration and differ only in source location.
    CollectionReaderDescription trainReader = CollectionReaderFactory.createReaderDescription(
            TeiReader.class,
            TeiReader.PARAM_LANGUAGE, "en",
            TeiReader.PARAM_SOURCE_LOCATION, corpusFilePathTrain,
            TeiReader.PARAM_PATTERNS, "*.xml");
    CollectionReaderDescription testReader = CollectionReaderFactory.createReaderDescription(
            TeiReader.class,
            TeiReader.PARAM_LANGUAGE, "en",
            TeiReader.PARAM_SOURCE_LOCATION, corpusFilePathTest,
            TeiReader.PARAM_PATTERNS, "*.xml");

    // configure training and test data reader dimension
    Map<String, Object> readerBundle = new HashMap<String, Object>();
    readerBundle.put(DIM_READER_TRAIN, trainReader);
    readerBundle.put(DIM_READER_TEST, testReader);

    return new ParameterSpace(
            Dimension.createBundle("readers", readerBundle),
            Dimension.create(DIM_FEATURE_MODE, Constants.FM_SEQUENCE),
            Dimension.create(DIM_LEARNING_MODE, Constants.LM_SINGLE_LABEL),
            Dimension.create(DeepLearningConstants.DIM_PRETRAINED_EMBEDDINGS,
                    "src/test/resources/wordvector/glove.6B.50d_250.txt"),
            Dimension.create(DeepLearningConstants.DIM_VECTORIZE_TO_INTEGER, false),
            Dimension.create(DeepLearningConstants.DIM_USE_ONLY_VOCABULARY_COVERED_BY_EMBEDDING, true),
            Dimension.create(DeepLearningConstants.DIM_USER_CODE, new Dl4jSeq2SeqUserCode()));
}
Use of org.dkpro.lab.task.ParameterSpace in project dkpro-tc by dkpro.
The class DynetDocumentTrainTest, method main.
/**
 * Entry point: sets the DKPRO_HOME system property and runs the train/test
 * experiment with a parameter space built for a fixed Python interpreter path.
 *
 * @param args command-line arguments (not used)
 * @throws Exception if building the parameter space or running the experiment fails
 */
public static void main(String[] args) throws Exception {
    // DKPRO_HOME is required by the experiment. Setting it here lets people
    // run the demo even if they have not read the setup instructions first.
    // A DemoUtils.setDkproHome(...) call was used here previously; it is
    // replaced by pointing DKPRO_HOME at the user's Desktop folder.
    String dkproHome = System.getProperty("user.home") + "/Desktop";
    System.setProperty("DKPRO_HOME", dkproHome);

    DynetDocumentTrainTest.runTrainTest(getParameterSpace("/usr/local/bin/python3"));
}
Use of org.dkpro.lab.task.ParameterSpace in project dkpro-tc by dkpro.
The class CRFSuiteBrownPosDemoTest, method runTrainTestNoFilter.
/**
 * Runs the CRFsuite train/test experiment with no filter (null is passed in
 * the last argument of getParameterSpace) and checks the header and the first
 * five entries of the single outcome file it produces.
 *
 * @throws Exception if the experiment or reading the outcome file fails
 */
@Test
public void runTrainTestNoFilter() throws Exception {
    // Classifier configuration bundle: algorithm arguments, data writer and sparse-feature flag.
    Map<String, Object> classifierConfig = new HashMap<>();
    classifierConfig.put(DIM_CLASSIFICATION_ARGS,
            new Object[] { new CrfSuiteAdapter(), CrfSuiteAdapter.ALGORITHM_LBFGS, "-p", "max_iterations=5" });
    classifierConfig.put(DIM_DATA_WRITER, new CrfSuiteAdapter().getDataWriterClass().getName());
    classifierConfig.put(DIM_FEATURE_USE_SPARSE, new CrfSuiteAdapter().useSparseFeatures());
    Dimension<Map<String, Object>> configDimension = Dimension.createBundle("config", classifierConfig);

    ParameterSpace parameterSpace = CRFSuiteBrownPosDemoSimpleDkproReader.getParameterSpace(
            Constants.FM_SEQUENCE, Constants.LM_SINGLE_LABEL, configDimension, null);
    javaExperiment.runTrainTest(parameterSpace);

    assertEquals(1, ContextMemoryReport.id2outcomeFiles.size());
    List<String> outcome = FileUtils.readLines(ContextMemoryReport.id2outcomeFiles.get(0), "utf-8");
    assertEquals(34, outcome.size());

    // Header lines: column description and label mapping.
    assertEquals("#ID=PREDICTION;GOLDSTANDARD;THRESHOLD", outcome.get(0));
    assertEquals("#labels 0=NN 1=JJ 2=NP 3=DTS 4=BEDZ 5=HV 6=PPO 7=DT 8=NNS 9=PPS 10=JJT 11=ABX 12=MD 13=DOD 14=VBD 15=VBG 16=QL 32=%28null%29 17=pct 18=CC 19=VBN 20=NPg 21=IN 22=WDT 23=BEN 24=VB 25=BER 26=AP 27=RB 28=CS 29=AT 30=HVD 31=TO", outcome.get(1));

    // Index 2 is not checked (presumably the time stamp line).
    // CRFsuite results are sensitive to the platform to some extent. To
    // account for this, only check that the "prediction" field holds some
    // number, without testing for a specific value.
    String[] expectedEntries = {
        "0000_0000_0000_The=[0-9]+;29;-1",
        "0000_0000_0001_bill=[0-9]+;0;-1",
        "0000_0000_0002_,=[0-9]+;17;-1",
        "0000_0000_0003_which=[0-9]+;22;-1",
        "0000_0000_0004_Daniel=[0-9]+;2;-1"
    };
    final int firstEntryLine = 3;
    for (int i = 0; i < expectedEntries.length; i++) {
        assertTrue(outcome.get(firstEntryLine + i).matches(expectedEntries[i]));
    }
}
Aggregations