Search in sources :

Example 1 with MultilabelResult

use of org.dkpro.tc.ml.weka.util.MultilabelResult in project dkpro-tc by dkpro.

In the class WekaOutcomeIDReport, method execute.

/**
 * Builds the outcome-id report from the stored predictions and writes it as a properties file.
 *
 * <p>The predictions file ({@code FILENAME_PREDICTIONS}) is loaded as {@link Instances}. Depending
 * on {@code isMultiLabel}, the properties are produced either by {@code generateMlProperties}
 * (using the multi-label result read from {@code WekaTestTask.evaluationBin}) or by
 * {@code generateSlProperties} (using the map returned by {@code loadDocumentMap()}). The result
 * is stored to {@code getTargetOutputFile()} in UTF-8, with {@code generateHeader(labels)} as the
 * properties comment header.
 *
 * @throws Exception if any of the helper calls fails, or if writing or closing the output file
 *         fails
 */
@Override
public void execute() throws Exception {
    init();
    File arff = WekaUtils.getFile(getContext(), "", FILENAME_PREDICTIONS, AccessMode.READONLY);
    // The evaluation file handle is resolved in both modes; it is only read in the multi-label case.
    mlResults = WekaUtils.getFile(getContext(), "", WekaTestTask.evaluationBin, AccessMode.READONLY);
    Instances predictions = WekaUtils.getInstances(arff, isMultiLabel);
    List<String> labels = getLabels(isMultiLabel, isRegression);
    Properties props;
    if (isMultiLabel) {
        MultilabelResult r = WekaUtils.readMlResultFromFile(mlResults);
        props = generateMlProperties(predictions, labels, r);
    } else {
        Map<Integer, String> documentIdMap = loadDocumentMap();
        props = generateSlProperties(predictions, isRegression, isUnit, documentIdMap, labels);
    }
    // try-with-resources closes the writer on every path and, unlike closeQuietly, lets a
    // failing close() surface to the caller instead of being silently discarded.
    try (FileWriterWithEncoding fw = new FileWriterWithEncoding(getTargetOutputFile(), "utf-8")) {
        props.store(fw, generateHeader(labels));
    }
}
Also used : Instances(weka.core.Instances) FileWriterWithEncoding(org.apache.commons.io.output.FileWriterWithEncoding) MultilabelResult(org.dkpro.tc.ml.weka.util.MultilabelResult) Properties(java.util.Properties) SortedKeyProperties(org.dkpro.tc.ml.report.util.SortedKeyProperties) File(java.io.File)

Example 2 with MultilabelResult

use of org.dkpro.tc.ml.weka.util.MultilabelResult in project dkpro-tc by dkpro.

In the class WekaTestTask, method execute.

/**
 * Runs one train/test cycle: loads both data sets, optionally applies feature selection,
 * evaluates the configured classifier, stores the evaluation result and writes the predictions.
 *
 * <p>The evaluation result goes to the {@code evaluationBin} file; the predicted instances are
 * written to the {@code Constants.FILENAME_PREDICTIONS} file.
 *
 * @param aContext task context used to resolve all input and output files
 * @throws Exception if loading, feature selection, training, evaluation or writing fails
 */
@Override
public void execute(TaskContext aContext) throws Exception {
    final boolean multiLabel = learningMode.equals(Constants.LM_MULTI_LABEL);

    // --- load training and test data ---
    File trainArff = WekaUtils.getFile(aContext, TEST_TASK_INPUT_KEY_TRAINING_DATA, Constants.FILENAME_DATA_IN_CLASSIFIER_FORMAT, AccessMode.READONLY);
    File testArff = WekaUtils.getFile(aContext, TEST_TASK_INPUT_KEY_TEST_DATA, Constants.FILENAME_DATA_IN_CLASSIFIER_FORMAT, AccessMode.READONLY);
    Instances train = WekaUtils.getInstances(trainArff, multiLabel);
    Instances test = WekaUtils.getInstances(testArff, multiLabel);

    // outcome classes are not balanced in regression experiments
    boolean regression = learningMode.equals(Constants.LM_REGRESSION);
    if (!regression) {
        test = WekaUtils.makeOutcomeClassesCompatible(train, test, multiLabel);
    }

    // keep a copy taken before the instance ids are stripped; it is handed to addInstanceId later
    Instances testWithIds = new Instances(test);
    train = WekaUtils.removeInstanceId(train, multiLabel);
    test = WekaUtils.removeInstanceId(test, multiLabel);

    // --- feature selection ---
    if (multiLabel) {
        boolean selectionConfigured = attributeEvaluator != null
                && labelTransformationMethod != null
                && numLabelsToKeep > 0;
        if (selectionConfigured) {
            Remove removeFilter = WekaUtils.featureSelectionMultilabel(aContext, train, attributeEvaluator, labelTransformationMethod, numLabelsToKeep);
            if (applySelection) {
                Logger.getLogger(getClass()).info("APPLYING FEATURE SELECTION");
                train = WekaUtils.applyAttributeSelectionFilter(train, removeFilter);
                test = WekaUtils.applyAttributeSelectionFilter(test, removeFilter);
            }
        }
    } else if (featureSearcher != null && attributeEvaluator != null) {
        AttributeSelection selection = WekaUtils.featureSelectionSinglelabel(aContext, train, featureSearcher, attributeEvaluator);
        // the selection summary is written out whether or not the selection is applied
        File selectionReport = WekaUtils.getFile(aContext, "", WekaTestTask.featureSelectionFile, AccessMode.READWRITE);
        FileUtils.writeStringToFile(selectionReport, selection.toResultsString(), "utf-8");
        if (applySelection) {
            Logger.getLogger(getClass()).info("APPLYING FEATURE SELECTION");
            train = selection.reduceDimensionality(train);
            test = selection.reduceDimensionality(test);
        }
    }

    // --- classifier and evaluation output file ---
    Classifier classifier = WekaUtils.getClassifier(learningMode, classificationArguments);
    File evalFile = WekaUtils.getFile(aContext, "", evaluationBin, AccessMode.READWRITE);

    // --- evaluation and prediction generation ---
    if (multiLabel) {
        // no explicit buildClassifier call here - meka builds the classifier internally
        Result mekaResult = WekaUtils.getEvaluationMultilabel(classifier, train, test, threshold);
        MultilabelResult stored = new MultilabelResult(mekaResult.allTrueValues(), mekaResult.allPredictions(), threshold);
        WekaUtils.writeMlResultToFile(stored, evalFile);
        Instances predicted = WekaUtils.getPredictionInstancesMultiLabel(test, classifier, WekaUtils.getMekaThreshold(threshold, mekaResult, train));
        test = WekaUtils.addInstanceId(predicted, testWithIds, true);
    } else {
        // single-label setup: the classifier has to be trained explicitly on the training split
        classifier.buildClassifier(train);
        String evalPath = evalFile.getAbsolutePath();
        Object evaluation = WekaUtils.getEvaluationSinglelabel(classifier, train, test);
        weka.core.SerializationHelper.write(evalPath, evaluation);
        Instances predicted = WekaUtils.getPredictionInstancesSingleLabel(test, classifier);
        test = WekaUtils.addInstanceId(predicted, testWithIds, false);
    }

    // The data sink expects a file ending in .arff, so the predictions are first written to a
    // temporary .arff file next to the target and then moved onto the real prediction file.
    File predictionFile = WekaUtils.getFile(aContext, "", Constants.FILENAME_PREDICTIONS, AccessMode.READWRITE);
    File tempArff = new File(predictionFile.getParent(), "prediction.arff");
    DataSink.write(tempArff.getAbsolutePath(), test);
    FileUtils.moveFile(tempArff, predictionFile);
}
Also used : Instances(weka.core.Instances) AttributeSelection(weka.attributeSelection.AttributeSelection) MultilabelResult(org.dkpro.tc.ml.weka.util.MultilabelResult) Remove(weka.filters.unsupervised.attribute.Remove) Classifier(weka.classifiers.Classifier) File(java.io.File) Result(meka.core.Result)

Aggregations

File (java.io.File)2 MultilabelResult (org.dkpro.tc.ml.weka.util.MultilabelResult)2 Instances (weka.core.Instances)2 Properties (java.util.Properties)1 Result (meka.core.Result)1 FileWriterWithEncoding (org.apache.commons.io.output.FileWriterWithEncoding)1 SortedKeyProperties (org.dkpro.tc.ml.report.util.SortedKeyProperties)1 AttributeSelection (weka.attributeSelection.AttributeSelection)1 Classifier (weka.classifiers.Classifier)1 Remove (weka.filters.unsupervised.attribute.Remove)1