Search in sources:

Example 11 with PredictorRecord

Use of org.knime.base.node.mine.treeensemble2.data.PredictorRecord in project knime-core by KNIME.

Class LKGradientBoostingPredictorCellFactory, method getCells.

/**
 * Predicts the class of a single row with the gradient boosted trees model.
 * <p>
 * Every class score starts at the model's initial value and, for each level, is increased by the coefficient
 * associated with the leaf the row falls into in that class's tree. The scores are converted into class
 * probabilities with a softmax; the class with the highest probability is the prediction.
 *
 * @param row the input row; only the learning columns are used for the prediction
 * @return the predicted class label, optionally followed by its probability and by the probabilities of all target
 *         values (in the iteration order of {@code m_targetValueMap}); all cells are missing if no predictor record
 *         could be created for the row
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final DataRow filterRow = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    final int nrClasses = m_model.getNrClasses();
    final int nrLevels = m_model.getNrLevels();
    final PredictorRecord record = m_model.createPredictorRecord(filterRow, m_learnSpec);
    if (record == null) {
        // a null record signals a missing value; without this check findMatchingNode would be handed null
        int nrCells = 1;
        if (m_config.isAppendPredictionConfidence()) {
            nrCells++;
        }
        if (m_config.isAppendClassConfidences()) {
            nrCells += m_targetValueMap.size();
        }
        final DataCell[] missingCells = new DataCell[nrCells];
        Arrays.fill(missingCells, org.knime.core.data.DataType.getMissingCell());
        return missingCells;
    }
    final double[] classFunctionPredictions = new double[nrClasses];
    Arrays.fill(classFunctionPredictions, m_model.getInitialValue());
    for (int i = 0; i < nrLevels; i++) {
        for (int j = 0; j < nrClasses; j++) {
            final TreeNodeRegression matchingNode = m_model.getModel(i, j).findMatchingNode(record);
            classFunctionPredictions[j] += m_model.getCoefficientMap(i, j).get(matchingNode.getSignature());
        }
    }
    // Softmax. Shifting every score by the maximum leaves the probabilities unchanged but keeps Math.exp from
    // overflowing to Infinity, which would otherwise yield NaN probabilities and leave classIdx at -1.
    double maxScore = Double.NEGATIVE_INFINITY;
    for (final double score : classFunctionPredictions) {
        maxScore = Math.max(maxScore, score);
    }
    final double[] classProbabilities = new double[nrClasses];
    double expSum = 0;
    for (int i = 0; i < nrClasses; i++) {
        classProbabilities[i] = Math.exp(classFunctionPredictions[i] - maxScore);
        expSum += classProbabilities[i];
    }
    int classIdx = -1;
    double classProb = -1;
    for (int i = 0; i < nrClasses; i++) {
        classProbabilities[i] /= expSum;
        if (classProbabilities[i] > classProb) {
            classIdx = i;
            classProb = classProbabilities[i];
        }
    }
    final ArrayList<DataCell> cells = new ArrayList<DataCell>();
    cells.add(new StringCell(m_model.getClassLabel(classIdx)));
    if (m_config.isAppendPredictionConfidence()) {
        cells.add(new DoubleCell(classProb));
    }
    if (m_config.isAppendClassConfidences()) {
        // the map is necessary to ensure that the probabilities are correctly associated with the column header
        final Map<String, Double> classProbMap = new HashMap<String, Double>((int) (nrClasses * 1.5));
        for (int i = 0; i < nrClasses; i++) {
            classProbMap.put(m_model.getClassLabel(i), classProbabilities[i]);
        }
        for (final String className : m_targetValueMap.keySet()) {
            cells.add(new DoubleCell(classProbMap.get(className)));
        }
    }
    return cells.toArray(new DataCell[cells.size()]);
}
Also used : HashMap(java.util.HashMap) DoubleCell(org.knime.core.data.def.DoubleCell) ArrayList(java.util.ArrayList) DataRow(org.knime.core.data.DataRow) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) StringCell(org.knime.core.data.def.StringCell) PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) DataCell(org.knime.core.data.DataCell) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow)

Example 12 with PredictorRecord

Use of org.knime.base.node.mine.treeensemble2.data.PredictorRecord in project knime-core by KNIME.

Class TreeEnsembleRegressionPredictorCellFactory, method getCells.

/**
 * Predicts the regression target for one row by averaging the means of the leaves the row reaches in the trees of
 * the ensemble.
 * <p>
 * When an out-of-bag filter is active, trees that were trained on this row are skipped. Depending on the
 * configuration, the variance of the contributing leaf means (used as prediction confidence) and the number of
 * contributing trees are appended.
 *
 * @param row the input row
 * @return the mean prediction, followed by the optional variance and model-count cells; all cells are missing if no
 *         predictor record could be built for the row
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final TreeEnsembleModelPortObject modelObject = m_predictor.getModelObject();
    final TreeEnsemblePredictorConfiguration config = m_predictor.getConfiguration();
    final TreeEnsembleModel ensemble = modelObject.getEnsembleModel();
    final boolean withConfidence = config.isAppendPredictionConfidence();
    final boolean withModelCount = config.isAppendModelCount();
    final int cellCount = 1 + (withConfidence ? 1 : 0) + (withModelCount ? 1 : 0);
    final boolean outOfBagOnly = m_predictor.hasOutOfBagFilter();
    final DataCell[] cells = new DataCell[cellCount];
    final PredictorRecord record =
        ensemble.createPredictorRecord(new FilterColumnRow(row, m_learnColumnInRealDataIndices), m_learnSpec);
    if (record == null) {
        // a null record signals a missing value: emit missing cells only
        Arrays.fill(cells, DataType.getMissingCell());
        return cells;
    }
    final Mean mean = new Mean();
    final Variance variance = new Variance();
    final int treeCount = ensemble.getNrModels();
    for (int tree = 0; tree < treeCount; tree++) {
        if (outOfBagOnly && m_predictor.isRowPartOfTrainingData(row.getKey(), tree)) {
            continue; // this tree was trained on the row, so it must not contribute
        }
        final TreeNodeRegression leaf = ensemble.getTreeModelRegression(tree).findMatchingNode(record);
        final double leafMean = leaf.getMean();
        mean.increment(leafMean);
        variance.increment(leafMean);
    }
    final int contributingTrees = (int) mean.getN();
    final boolean noContribution = contributingTrees == 0;
    int next = 0;
    cells[next++] = noContribution ? DataType.getMissingCell() : new DoubleCell(mean.getResult());
    if (withConfidence) {
        cells[next++] = noContribution ? DataType.getMissingCell() : new DoubleCell(variance.getResult());
    }
    if (withModelCount) {
        cells[next] = new IntCell(contributingTrees);
    }
    return cells;
}
Also used : Mean(org.apache.commons.math.stat.descriptive.moment.Mean) TreeEnsembleModel(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModel) DoubleCell(org.knime.core.data.def.DoubleCell) TreeEnsemblePredictorConfiguration(org.knime.base.node.mine.treeensemble2.node.predictor.TreeEnsemblePredictorConfiguration) DataRow(org.knime.core.data.DataRow) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) Variance(org.apache.commons.math.stat.descriptive.moment.Variance) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) IntCell(org.knime.core.data.def.IntCell) TreeEnsembleModelPortObject(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject) PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) DataCell(org.knime.core.data.DataCell) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow)

Example 13 with PredictorRecord

Use of org.knime.base.node.mine.treeensemble2.data.PredictorRecord in project knime-core by KNIME.

Class TreeNodeNumericConditionTest, method testTestCondition.

/**
 * Tests {@link TreeNodeNumericCondition#testCondition(org.knime.base.node.mine.treeensemble2.data.PredictorRecord)}
 * for the {@code LessThanOrEqual} and {@code LargerThan} operators. Each operator is tested once with missing
 * values rejected and once with them accepted (last constructor argument {@code false} / {@code true}).
 * <p>
 * The same record instance is reused throughout; its value is changed by rewriting the backing map. This assumes
 * that PredictorRecord reads from the map it was constructed with rather than from a copy.
 *
 * @throws Exception if the test setup fails
 */
@Test
public void testTestCondition() throws Exception {
    final TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(false);
    final TestDataGenerator dataGen = new TestDataGenerator(config);
    final TreeNumericColumnData col = dataGen.createNumericAttributeColumn("1,2,3,4,4,5,6,7", "testCol", 0);
    // "<= 3", missing values rejected
    TreeNodeNumericCondition cond = new TreeNodeNumericCondition(col.getMetaData(), 3, NumericOperator.LessThanOrEqual, false);
    final Map<String, Object> map = Maps.newHashMap();
    final String colName = col.getMetaData().getAttributeName();
    map.put(colName, 2.5);
    final PredictorRecord record = new PredictorRecord(map);
    assertTrue("2.5 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertTrue("3 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 4);
    assertFalse("4 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertFalse("Missing values were falsely accepted.", cond.testCondition(record));
    // "<= 3", missing values accepted
    cond = new TreeNodeNumericCondition(col.getMetaData(), 3, NumericOperator.LessThanOrEqual, true);
    map.clear();
    map.put(colName, 2.5);
    assertTrue("2.5 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertTrue("3 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 4);
    assertFalse("4 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertTrue("Missing values were falsely rejected.", cond.testCondition(record));
    // "> 4", missing values rejected
    cond = new TreeNodeNumericCondition(col.getMetaData(), 4, NumericOperator.LargerThan, false);
    map.clear();
    map.put(colName, 2.5);
    assertFalse("2.5 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertFalse("3 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 4);
    assertFalse("4 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 4.01);
    assertTrue("4.01 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertFalse("Missing values were falsely accepted.", cond.testCondition(record));
    // "> 4", missing values accepted
    cond = new TreeNodeNumericCondition(col.getMetaData(), 4, NumericOperator.LargerThan, true);
    map.clear();
    map.put(colName, 2.5);
    assertFalse("2.5 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertFalse("3 was falsely accepted.", cond.testCondition(record));
    map.clear();
    // boundary value: 4 is not larger than 4, regardless of how missing values are treated
    map.put(colName, 4);
    assertFalse("4 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 4.01);
    assertTrue("4.01 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertTrue("Missing values were falsely rejected.", cond.testCondition(record));
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) TreeNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeNumericColumnData) TestDataGenerator(org.knime.base.node.mine.treeensemble2.data.TestDataGenerator) Test(org.junit.Test)

Example 14 with PredictorRecord

Use of org.knime.base.node.mine.treeensemble2.data.PredictorRecord in project knime-core by KNIME.

Class AbstractGradientBoostingLearner, method createPredictorRecord.

/**
 * Builds a {@link PredictorRecord} for one row of the in-memory {@link TreeData}.
 * <p>
 * For every attribute column, the index manager translates {@code rowIdx} into the position at which that row's
 * value is stored in the column. The value read there is passed through {@code handleMissingValues} before being
 * stored under the attribute name; presumably this replaces missing values with a placeholder (verify there).
 *
 * @param data the in-memory data whose attribute columns are read
 * @param indexManager supplies, per attribute index, the positions of the rows within that column
 * @param rowIdx the index of the row to convert
 * @return a PredictorRecord for the row at <b>rowIdx</b> in <b>data</b>, keyed by attribute name
 */
public static PredictorRecord createPredictorRecord(final TreeData data, final IDataIndexManager indexManager,
    final int rowIdx) {
    final Map<String, Object> valuesByAttribute = new HashMap<String, Object>();
    for (final TreeAttributeColumnData column : data.getColumns()) {
        final TreeAttributeColumnMetaData meta = column.getMetaData();
        final int position = indexManager.getPositionsInColumn(meta.getAttributeIndex())[rowIdx];
        valuesByAttribute.put(meta.getAttributeName(), handleMissingValues(column.getValueAt(position), column));
    }
    return new PredictorRecord(valuesByAttribute);
}
Also used : TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) HashMap(java.util.HashMap) PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) TreeAttributeColumnMetaData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnMetaData)

Example 15 with PredictorRecord

Use of org.knime.base.node.mine.treeensemble2.data.PredictorRecord in project knime-core by KNIME.

Class AbstractTreeEnsembleModel, method createBitVectorPredictorRecord.

/**
 * Converts a row holding a single bit-vector cell into a {@link PredictorRecord} with one Boolean entry per bit.
 *
 * @param filterRow a row expected to contain exactly one bit-vector cell
 * @return the record with entries keyed by {@code TreeBitColumnMetaData.getAttributeName(bitIndex)}, in bit order,
 *         or {@code null} if the cell is missing
 * @throws IllegalArgumentException if the bit vector's length differs from the model's number of attributes
 */
private PredictorRecord createBitVectorPredictorRecord(final DataRow filterRow) {
    assert filterRow.getNumCells() == 1 : "Expected one cell as bit vector data";
    final DataCell cell = filterRow.getCell(0);
    if (cell.isMissing()) {
        // no record can be built from a missing cell; callers receive null
        return null;
    }
    final BitVectorValue bitVector = (BitVectorValue) cell;
    final long nrBits = bitVector.length();
    if (nrBits != getMetaData().getNrAttributes()) {
        throw new IllegalArgumentException("The bit-vector in " + filterRow.getKey().getString()
            + " has the wrong length. (" + nrBits + " instead of " + getMetaData().getNrAttributes() + ")");
    }
    // presize so that all entries fit without rehashing under the default load factor of 0.75
    final int initialCapacity = (int) (nrBits / 0.75 + 1.0);
    final Map<String, Object> bitsByAttribute = new LinkedHashMap<String, Object>(initialCapacity);
    int bit = 0;
    while (bit < nrBits) {
        bitsByAttribute.put(TreeBitColumnMetaData.getAttributeName(bit), Boolean.valueOf(bitVector.get(bit)));
        bit++;
    }
    return new PredictorRecord(bitsByAttribute);
}
Also used : PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) DataCell(org.knime.core.data.DataCell) BitVectorValue(org.knime.core.data.vector.bitvector.BitVectorValue) LinkedHashMap(java.util.LinkedHashMap)

Aggregations

- PredictorRecord (org.knime.base.node.mine.treeensemble2.data.PredictorRecord): 16
- DataCell (org.knime.core.data.DataCell): 9
- FilterColumnRow (org.knime.base.data.filter.column.FilterColumnRow): 6
- DataRow (org.knime.core.data.DataRow): 6
- DoubleCell (org.knime.core.data.def.DoubleCell): 5
- LinkedHashMap (java.util.LinkedHashMap): 4
- Test (org.junit.Test): 3
- NominalValueRepresentation (org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation): 3
- TestDataGenerator (org.knime.base.node.mine.treeensemble2.data.TestDataGenerator): 3
- TreeEnsembleModel (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModel): 3
- TreeEnsembleModelPortObject (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject): 3
- TreeNodeRegression (org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression): 3
- TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration): 3
- TreeEnsemblePredictorConfiguration (org.knime.base.node.mine.treeensemble2.node.predictor.TreeEnsemblePredictorConfiguration): 3
- IntCell (org.knime.core.data.def.IntCell): 3
- HashMap (java.util.HashMap): 2
- TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData): 2
- TreeNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeNominalColumnData): 2
- TreeTargetNominalColumnMetaData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnMetaData): 2
- IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager): 2