Search in sources:

Example 1 with Add

use of weka.filters.unsupervised.attribute.Add in project dkpro-tc by dkpro.

Class WekaUtils, method makeOutcomeClassesCompatible.

/**
 * Adapts the test data class labels to the training data. Class labels from the test data
 * unseen in the training data will be deleted from the test data. Class labels from the
 * training data unseen in the test data will be added to the test data. The test data is
 * always rebuilt, even if training and test class labels are equal.
 * <p>
 * Single-label: a new nominal class attribute carrying the training labels is appended. Test
 * instances whose label is missing, or unknown to the training data, get a missing value in it.
 * The old class attribute is then removed and the new one becomes the class.
 * <p>
 * Multi-label: one new {@code 0,1} attribute per training label is inserted directly after the
 * test label attributes. It is filled from the test label attribute of the same name, or 0 if
 * the test data has no such label. The old label attributes are then removed and the relation
 * name is rewritten to carry the new number of labels.
 *
 * @param trainData
 *            train data
 * @param testData
 *            test data
 * @param multilabel
 *            whether the data is multi-label
 * @return a new, compatible test data set
 * @throws Exception
 *             in case of error
 */
@SuppressWarnings({ "rawtypes", "unchecked" })
public static Instances makeOutcomeClassesCompatible(Instances trainData, Instances testData, boolean multilabel) throws Exception {
    // new (compatible) test data
    Instances compTestData = null;
    // ================ SINGLE LABEL BRANCH ======================
    if (!multilabel) {
        // retrieve class labels
        Enumeration trainOutcomeValues = trainData.classAttribute().enumerateValues();
        Enumeration testOutcomeValues = testData.classAttribute().enumerateValues();
        ArrayList trainLabels = Collections.list(trainOutcomeValues);
        ArrayList testLabels = Collections.list(testOutcomeValues);
        // add new outcome class attribute (with the training labels) to test data
        Add addFilter = new Add();
        addFilter.setNominalLabels(StringUtils.join(trainLabels, ','));
        addFilter.setAttributeName(Constants.CLASS_ATTRIBUTE_NAME + COMPATIBLE_OUTCOME_CLASS);
        addFilter.setInputFormat(testData);
        testData = Filter.useFilter(testData, addFilter);
        // loop-invariant indices of the old and the new class attribute
        int oldClassIndex = testData.classIndex();
        int compatibleClassIndex = testData.attribute(Constants.CLASS_ATTRIBUTE_NAME + COMPATIBLE_OUTCOME_CLASS).index();
        // fill NEW test data with values from old test data plus the new class attribute
        compTestData = new Instances(testData, testData.numInstances());
        for (int i = 0; i < testData.numInstances(); i++) {
            weka.core.Instance instance = testData.instance(i);
            // A missing label must be detected explicitly: (int) NaN == 0 would otherwise
            // silently map it to the first test label.
            String label = null;
            if (!instance.isMissing(oldClassIndex)) {
                label = (String) testLabels.get((int) instance.value(oldClassIndex));
            }
            if (label != null && trainLabels.contains(label)) {
                instance.setValue(compatibleClassIndex, label);
            } else {
                // label unknown to the training data (or absent): leave the outcome undefined
                instance.setMissing(compatibleClassIndex);
            }
            compTestData.add(instance);
        }
        // remove old class attribute (Remove expects 1-based indices)
        Remove remove = new Remove();
        remove.setAttributeIndices(Integer.toString(compTestData.attribute(Constants.CLASS_ATTRIBUTE_NAME).index() + 1));
        remove.setInvertSelection(false);
        remove.setInputFormat(compTestData);
        compTestData = Filter.useFilter(compTestData, remove);
        // set new class attribute
        compTestData.setClass(compTestData.attribute(Constants.CLASS_ATTRIBUTE_NAME + COMPATIBLE_OUTCOME_CLASS));
    } else // ================ MULTI LABEL BRANCH ======================
    {
        // the label attributes come first, so the class index equals the number of labels
        int numTrainLabels = trainData.classIndex();
        int numTestLabels = testData.classIndex();
        ArrayList<String> trainLabels = getLabels(trainData);
        // add new outcome class attributes to test data
        Add filter = new Add();
        for (int i = 0; i < numTrainLabels; i++) {
            // 1-based index: the i-th new attribute lands at 0-based position numTestLabels + i,
            // i.e. directly after the old test label attributes
            filter.setAttributeIndex(Integer.toString(numTestLabels + i + 1));
            filter.setNominalLabels("0,1");
            filter.setAttributeName(trainData.attribute(i).name() + COMPATIBLE_OUTCOME_CLASS);
            filter.setInputFormat(testData);
            testData = Filter.useFilter(testData, filter);
        }
        // fill NEW test data with values from old test data plus the new class attributes
        compTestData = new Instances(testData, testData.numInstances());
        for (int i = 0; i < testData.numInstances(); i++) {
            weka.core.Instance instance = testData.instance(i);
            // default every training label to 0 (not present in this instance)
            for (int j = 0; j < numTrainLabels; j++) {
                instance.setValue(j + numTestLabels, 0.);
            }
            // copy the real values of test labels that are also training labels;
            // test labels unknown to the training data are dropped
            for (int j = 0; j < numTestLabels; j++) {
                int index = trainLabels.indexOf(instance.attribute(j).name());
                if (index != -1) {
                    instance.setValue(index + numTestLabels, instance.value(j));
                }
            }
            compTestData.add(instance);
        }
        // remove the old class attributes (the first numTestLabels attributes) in one pass
        if (numTestLabels > 0) {
            Remove remove = new Remove();
            remove.setAttributeIndices("1-" + numTestLabels);
            remove.setInvertSelection(false);
            remove.setInputFormat(compTestData);
            compTestData = Filter.useFilter(compTestData, remove);
        }
        // adapt header and set new class label
        // NOTE(review): expects a "-C" marker (presumably "-C <numLabels>") in the relation
        // name; if none occurs, indexOf returns -1 and only the first character is kept
        String relationTag = compTestData.relationName();
        compTestData.setRelationName(relationTag.substring(0, relationTag.indexOf("-C") + 2) + " " + numTrainLabels + " ");
        compTestData.setClassIndex(numTrainLabels);
    }
    return compTestData;
}
Also used: Instances (weka.core.Instances), MultiLabelInstances (mulan.data.MultiLabelInstances), Add (weka.filters.unsupervised.attribute.Add), Enumeration (java.util.Enumeration), ArrayList (java.util.ArrayList), Remove (weka.filters.unsupervised.attribute.Remove)

Example 2 with Add

use of weka.filters.unsupervised.attribute.Add in project dkpro-tc by dkpro.

Class WekaUtils, method getPredictionInstancesMultiLabel.

/**
 * Generates an instances object containing the predictions of a given multi-label classifier
 * for a given test set.
 * <p>
 * The first {@code testData.classIndex()} attributes are treated as the label attributes. For
 * each label, a nominal {@code 0,1} prediction attribute is inserted directly after the label
 * attributes. It is set to 1 when the classifier's score for that label is at least the label's
 * threshold, and to 0 otherwise.
 *
 * @param testData
 *            test set
 * @param cl
 *            multi-label classifier; it is only queried via {@code distributionForInstance},
 *            so it must already be trained and be compatible with the test set
 * @param thresholdArray
 *            an array of double, one threshold per label
 * @return the test data with one additional prediction attribute per label
 * @throws Exception
 *             an exception
 */
public static Instances getPredictionInstancesMultiLabel(Instances testData, Classifier cl, double[] thresholdArray) throws Exception {
    // the first classIndex() attributes are the labels
    int numLabels = testData.classIndex();
    // get predictions before the data set is rebuilt by the filters below
    List<double[]> labelPredictionList = new ArrayList<double[]>(testData.numInstances());
    for (int i = 0; i < testData.numInstances(); i++) {
        labelPredictionList.add(cl.distributionForInstance(testData.instance(i)));
    }
    // add attributes to store predictions in test data
    Add filter = new Add();
    for (int i = 0; i < numLabels; i++) {
        // 1-based index: the i-th prediction attribute lands at 0-based position numLabels + i,
        // directly after the label attributes
        filter.setAttributeIndex(Integer.toString(numLabels + i + 1));
        filter.setNominalLabels("0,1");
        filter.setAttributeName(testData.attribute(i).name() + "_" + WekaTestTask.PREDICTION_CLASS_LABEL_NAME);
        filter.setInputFormat(testData);
        testData = Filter.useFilter(testData, filter);
    }
    // fill predicted values for each instance
    for (int i = 0; i < labelPredictionList.size(); i++) {
        double[] scores = labelPredictionList.get(i);
        // never write past the prediction attributes, even if the classifier returns more
        // scores than there are labels (that would overwrite feature attributes)
        int numScores = Math.min(scores.length, numLabels);
        for (int j = 0; j < numScores; j++) {
            testData.instance(i).setValue(j + numLabels, scores[j] >= thresholdArray[j] ? 1. : 0.);
        }
    }
    return testData;
}
Also used: Add (weka.filters.unsupervised.attribute.Add), ArrayList (java.util.ArrayList)

Example 3 with Add

use of weka.filters.unsupervised.attribute.Add in project dkpro-tc by dkpro.

Class WekaUtils, method getPredictionInstancesSingleLabel.

/**
 * Generates an instances object containing the predictions of a given single-label classifier
 * for a given test set.
 * <p>
 * A prediction attribute is appended as the last attribute. If the class attribute has values,
 * the prediction attribute is made nominal with the same labels. Otherwise it keeps the Add
 * filter's default type (presumably numeric, e.g. for regression). It is filled with the result
 * of {@code classifyInstance} for each instance.
 *
 * @param testData
 *            weka instances
 * @param cl
 *            classifier, must already be trained (it is only queried)
 * @return the test data with the additional prediction attribute
 * @throws Exception
 *             in case of errors
 */
public static Instances getPredictionInstancesSingleLabel(Instances testData, Classifier cl) throws Exception {
    // comma-separated class labels, empty if the class attribute has no values
    StringBuilder classVals = new StringBuilder();
    for (int i = 0; i < testData.classAttribute().numValues(); i++) {
        if (classVals.length() > 0) {
            classVals.append(",");
        }
        classVals.append(testData.classAttribute().value(i));
    }
    // get predictions before the data set is rebuilt by the filter
    List<Double> labelPredictionList = new ArrayList<Double>(testData.size());
    for (int i = 0; i < testData.size(); i++) {
        labelPredictionList.add(cl.classifyInstance(testData.instance(i)));
    }
    // add an attribute with the predicted values at the end of the attributes
    Add filter = new Add();
    filter.setAttributeName(WekaTestTask.PREDICTION_CLASS_LABEL_NAME);
    if (classVals.length() > 0) {
        filter.setAttributeType(new SelectedTag(Attribute.NOMINAL, Add.TAGS_TYPE));
        filter.setNominalLabels(classVals.toString());
    }
    filter.setInputFormat(testData);
    testData = Filter.useFilter(testData, filter);
    // The new attribute is the last one. Deriving its position from the class index
    // (classIndex() + 1) is only right when the class is the last attribute, and would
    // otherwise overwrite an unrelated attribute.
    int predictionIndex = testData.numAttributes() - 1;
    // fill predicted values for each instance
    for (int i = 0; i < labelPredictionList.size(); i++) {
        testData.instance(i).setValue(predictionIndex, labelPredictionList.get(i));
    }
    return testData;
}
Also used: Add (weka.filters.unsupervised.attribute.Add), SelectedTag (weka.core.SelectedTag), ArrayList (java.util.ArrayList)

Example 4 with Add

use of weka.filters.unsupervised.attribute.Add in project dkpro-tc by dkpro.

Class WekaUtils, method addInstanceId.

/**
 * Copies the instanceId attribute and its values from an existing data set, if present.
 * <p>
 * When {@code oldData} has an attribute named {@code Constants.ID_FEATURE_NAME}, a string
 * attribute of that name is added to a filtered copy of {@code newData`}. It goes first for
 * single-label data and last for multi-label data. Row {@code i} then receives the ID value of
 * row {@code i} of {@code oldData}. This assumes both data sets list the same instances in the
 * same order, and {@code oldData} must have at least as many instances as {@code newData}.
 * Without such an attribute, a plain copy of {@code newData} is returned.
 *
 * @param newData
 *            data set without instanceId attribute
 * @param oldData
 *            data set with or without instanceId attribute
 * @param isMultilabel
 *            whether this is multi-label processing (decides where the ID attribute goes)
 * @return a data set with or without instanceId attribute
 * @throws Exception
 *             an exception
 */
public static Instances addInstanceId(Instances newData, Instances oldData, boolean isMultilabel) throws Exception {
    Attribute sourceIdAttribute = oldData.attribute(Constants.ID_FEATURE_NAME);
    if (sourceIdAttribute == null) {
        // nothing to copy
        return new Instances(newData);
    }
    int sourceIdIndex = sourceIdAttribute.index();

    Add idAdder = new Add();
    idAdder.setAttributeName(Constants.ID_FEATURE_NAME);
    // multi-label data gets the ID at the end, single-label data at the front
    idAdder.setAttributeIndex(isMultilabel ? "last" : "first");
    idAdder.setAttributeType(new SelectedTag(Attribute.STRING, Add.TAGS_TYPE));
    idAdder.setInputFormat(newData);
    Instances withId = Filter.useFilter(newData, idAdder);

    int targetIdIndex;
    if (isMultilabel) {
        targetIdIndex = withId.numAttributes() - 1;
    } else {
        targetIdIndex = 0;
    }
    int rowCount = withId.numInstances();
    for (int row = 0; row < rowCount; row++) {
        String idValue = oldData.instance(row).stringValue(sourceIdIndex);
        withId.instance(row).setValue(targetIdIndex, idValue);
    }
    return withId;
}
Also used: Instances (weka.core.Instances), MultiLabelInstances (mulan.data.MultiLabelInstances), Add (weka.filters.unsupervised.attribute.Add), SelectedTag (weka.core.SelectedTag)

Aggregations

Add (weka.filters.unsupervised.attribute.Add): 4, ArrayList (java.util.ArrayList): 3, MultiLabelInstances (mulan.data.MultiLabelInstances): 2, Instances (weka.core.Instances): 2, SelectedTag (weka.core.SelectedTag): 2, Enumeration (java.util.Enumeration): 1, Remove (weka.filters.unsupervised.attribute.Remove): 1