Search in sources :

Example 16 with NominalValueRepresentation

Use of org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation in project knime-core by KNIME.

Class TreeModelPMMLTranslator, method addTreeNode.

/**
 * Recursively writes a tree node (and its whole subtree) into the given PMML node.
 *
 * <p>The PMML node receives the next id from {@code m_nodeIndex}. Then it receives a score: the
 * majority class name for classification nodes, or the mean rendered as a string for regression
 * nodes. Classification nodes additionally get a record count (sum of the target distribution) and
 * one score distribution entry per target value. After that the node's condition is translated into
 * the corresponding PMML predicate. Finally each child of {@code node} is appended as a new PMML
 * child node and processed recursively.
 *
 * @param pmmlNode the PMML node to populate
 * @param node the tree node whose content is written into {@code pmmlNode}
 * @throws IllegalStateException if the node's condition is {@code null} or of an unsupported type
 */
private void addTreeNode(final Node pmmlNode, final AbstractTreeNode node) {
    final int index = m_nodeIndex++;
    pmmlNode.setId(Integer.toString(index));
    if (node instanceof TreeNodeClassification) {
        final TreeNodeClassification clazNode = (TreeNodeClassification) node;
        pmmlNode.setScore(clazNode.getMajorityClassName());
        final float[] targetDistribution = clazNode.getTargetDistribution();
        final NominalValueRepresentation[] targetVals = clazNode.getTargetMetaData().getValues();
        // primitive loop variable: a Float variable would box every array element
        double sum = 0.0;
        for (final float v : targetDistribution) {
            sum += v;
        }
        pmmlNode.setRecordCount(sum);
        // adding score distribution (class counts), one entry per target value
        for (int i = 0; i < targetDistribution.length; i++) {
            final String className = targetVals[i].getNominalValue();
            final double freq = targetDistribution[i];
            final ScoreDistribution pmmlScoreDist = pmmlNode.addNewScoreDistribution();
            pmmlScoreDist.setValue(className);
            pmmlScoreDist.setRecordCount(freq);
        }
    } else if (node instanceof TreeNodeRegression) {
        final TreeNodeRegression regNode = (TreeNodeRegression) node;
        pmmlNode.setScore(Double.toString(regNode.getMean()));
    }
    final TreeNodeCondition condition = node.getCondition();
    if (condition instanceof TreeNodeTrueCondition) {
        pmmlNode.addNewTrue();
    } else if (condition instanceof TreeNodeColumnCondition) {
        final TreeNodeColumnCondition colCondition = (TreeNodeColumnCondition) condition;
        handleColumnCondition(colCondition, pmmlNode);
    } else if (condition instanceof AbstractTreeNodeSurrogateCondition) {
        final AbstractTreeNodeSurrogateCondition surrogateCond = (AbstractTreeNodeSurrogateCondition) condition;
        setValuesFromPMMLCompoundPredicate(pmmlNode.addNewCompoundPredicate(), surrogateCond.toPMMLPredicate());
    } else {
        // a null condition would otherwise surface as an unhelpful NullPointerException here
        final String conditionType = condition == null ? "null" : condition.getClass().getSimpleName();
        throw new IllegalStateException("Unsupported condition (not " + "implemented): " + conditionType);
    }
    for (int i = 0; i < node.getNrChildren(); i++) {
        addTreeNode(pmmlNode.addNewNode(), node.getChild(i));
    }
}
Also used : NominalValueRepresentation(org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation) ScoreDistribution(org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution)

Example 17 with NominalValueRepresentation

Use of org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation in project knime-core by KNIME.

Class TreeBitVectorColumnData, method calcBestSplitClassification.

/**
 * {@inheritDoc}
 *
 * <p>Rows of the current branch are split by whether their bit in this column is set ("on") or
 * cleared ("off"). For each side the target weight distribution is accumulated, and the resulting
 * binary split is scored with the impurity criterion from {@code targetPriors}.
 *
 * @return a {@link BitSplitCandidate} carrying the computed gain, or {@code null} if either side has
 *         a total weight below the configured minimum child size
 */
@Override
public SplitCandidate calcBestSplitClassification(final DataMemberships dataMemberships, final ClassificationPriors targetPriors, final TreeTargetNominalColumnData targetColumn, final RandomData rd) {
    final int nrTargetValues = targetColumn.getMetaData().getValues().length;
    final IImpurity impurity = targetPriors.getImpurityCriterion();
    final int minChildSize = getConfiguration().getMinChildSize();

    // per-target-value weights for rows whose bit is on ('1') and off ('0')
    final double[] onDistribution = new double[nrTargetValues];
    final double[] offDistribution = new double[nrTargetValues];
    double onTotal = 0.0;
    double offTotal = 0.0;

    final ColumnMemberships memberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
    while (memberships.next()) {
        final double rowWeight = memberships.getRowWeight();
        if (rowWeight < EPSILON) {
            // not expected to happen; with assertions disabled the row is simply skipped
            assert false : "This code should never be reached!";
            continue;
        }
        final int targetIndex = targetColumn.getValueFor(memberships.getOriginalIndex());
        final boolean bitIsOn = m_columnBitSet.get(memberships.getIndexInColumn());
        if (bitIsOn) {
            onTotal += rowWeight;
            onDistribution[targetIndex] += rowWeight;
        } else {
            offTotal += rowWeight;
            offDistribution[targetIndex] += rowWeight;
        }
    }

    // both children must reach the minimum weight for the split to be valid
    if (onTotal < minChildSize || offTotal < minChildSize) {
        return null;
    }

    final double totalWeight = onTotal + offTotal;
    final double[] childImpurities = new double[] {
        impurity.getPartitionImpurity(onDistribution, onTotal),
        impurity.getPartitionImpurity(offDistribution, offTotal)};
    final double[] childWeights = new double[] {onTotal, offTotal};
    final double postSplitImpurity = impurity.getPostSplitImpurity(childImpurities, childWeights, totalWeight);
    final double gain = impurity.getGain(targetPriors.getPriorImpurity(), postSplitImpurity, childWeights, totalWeight);
    return new BitSplitCandidate(this, gain);
}
Also used : BitSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.BitSplitCandidate) ColumnMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships) IImpurity(org.knime.base.node.mine.treeensemble2.learner.IImpurity)

Example 18 with NominalValueRepresentation

Use of org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation in project knime-core by KNIME.

Class BinaryNominalSplitsPCATest, method createTestAttVals.

/**
 * Builds the fixture of five nominal attribute values ("A" to "E") over three classes.
 *
 * <p>Each value's class frequencies sum to the shared total weight of 60. Consequently the derived
 * class probabilities (frequency / total weight) of each value sum to 1.
 *
 * @return one {@link CombinedAttributeValues} per nominal value, in the order "A" to "E"
 */
private static CombinedAttributeValues[] createTestAttVals() {
    final double[][] classFrequencies = {
        {40, 10, 10}, {10, 40, 10}, {20, 30, 10}, {20, 15, 25}, {10, 5, 45}};
    final String[] nomValStrings = {"A", "B", "C", "D", "E"};
    final double totalWeight = 60;

    final int nrValues = nomValStrings.length;
    final int nrClasses = classFrequencies[0].length;
    final double[][] classProbabilities = new double[nrValues][nrClasses];
    final NominalValueRepresentation[] nomVals = new NominalValueRepresentation[nrValues];
    for (int v = 0; v < nrValues; v++) {
        nomVals[v] = new NominalValueRepresentation(nomValStrings[v], v, totalWeight);
        final double[] frequencyRow = classFrequencies[v];
        final double[] probabilityRow = classProbabilities[v];
        for (int c = 0; c < nrClasses; c++) {
            probabilityRow[c] = frequencyRow[c] / totalWeight;
        }
    }

    final CombinedAttributeValues[] result = new CombinedAttributeValues[nrValues];
    for (int v = 0; v < nrValues; v++) {
        result[v] = new CombinedAttributeValues(classFrequencies[v], classProbabilities[v], totalWeight, nomVals[v]);
    }
    return result;
}
Also used : NominalValueRepresentation(org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation) CombinedAttributeValues(org.knime.base.node.mine.treeensemble2.data.BinaryNominalSplitsPCA.CombinedAttributeValues)

Example 19 with NominalValueRepresentation

Use of org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation in project knime-core by KNIME.

Class TreeNodeClassification, method createDecisionTreeNode.

/**
 * Converts this node and, recursively, its subtree into the DecisionTreeNode model used in the
 * Decision Tree of KNIME.
 *
 * <p>The score distribution maps each target value's name to its entry in the target distribution.
 * A node without children becomes a {@link DecisionTreeNodeLeaf}. Any other node becomes a
 * {@link DecisionTreeNodeSplitPMML} whose child predicates come from the children's conditions.
 * A node draws its id before its children do, so ids are assigned in pre-order.
 *
 * @param idGenerator source of node ids; incremented once for every node created
 * @param metaData tree meta data used to resolve the name of the split attribute
 * @return a DecisionTreeNode
 */
public DecisionTreeNode createDecisionTreeNode(final MutableInteger idGenerator, final TreeMetaData metaData) {
    final DataCell majorityCell = new StringCell(getMajorityClassName());
    final float[] distribution = getTargetDistribution();
    // sized so all entries fit without rehashing at the default load factor of 0.75
    final int capacity = (int) (distribution.length / 0.75 + 1.0);
    final LinkedHashMap<DataCell, Double> scoreDistributionMap = new LinkedHashMap<>(capacity);
    final NominalValueRepresentation[] targetValues = getTargetMetaData().getValues();
    for (int k = 0; k < distribution.length; k++) {
        final String className = targetValues[k].getNominalValue();
        final double classWeight = distribution[k];
        scoreDistributionMap.put(new StringCell(className), classWeight);
    }

    final int nrChildren = getNrChildren();
    if (nrChildren == 0) {
        return new DecisionTreeNodeLeaf(idGenerator.inc(), majorityCell, scoreDistributionMap);
    }

    // claim this node's id before recursing into the children
    final int ownId = idGenerator.inc();
    final DecisionTreeNode[] childNodes = new DecisionTreeNode[nrChildren];
    final int splitAttributeIndex = getSplitAttributeIndex();
    assert splitAttributeIndex >= 0 : "non-leaf node has no split";
    final String splitAttribute = metaData.getAttributeMetaData(splitAttributeIndex).getAttributeName();
    final PMMLPredicate[] childPredicates = new PMMLPredicate[nrChildren];
    for (int c = 0; c < nrChildren; c++) {
        final TreeNodeClassification child = getChild(c);
        final TreeNodeCondition childCondition = child.getCondition();
        childPredicates[c] = childCondition.toPMMLPredicate();
        childNodes[c] = child.createDecisionTreeNode(idGenerator, metaData);
    }
    return new DecisionTreeNodeSplitPMML(ownId, majorityCell, scoreDistributionMap, splitAttribute, childPredicates, childNodes);
}
Also used : NominalValueRepresentation(org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation) PMMLPredicate(org.knime.base.node.mine.decisiontree2.PMMLPredicate) LinkedHashMap(java.util.LinkedHashMap) DecisionTreeNodeLeaf(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf) StringCell(org.knime.core.data.def.StringCell) DecisionTreeNodeSplitPMML(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML) DataCell(org.knime.core.data.DataCell) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)

Example 20 with NominalValueRepresentation

Use of org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation in project knime-core by KNIME.

Class NominalColumnHelperUtil, method extractNomValReps.

/**
 * Creates one {@link NominalValueRepresentation} per possible value.
 *
 * <p>Each representation is labelled with the cell's {@code toString()}. Indices are assigned
 * consecutively from 0 in the iteration order of {@code possibleValues}.
 *
 * @param possibleValues the distinct values of a nominal column
 * @return an array whose i-th element represents the i-th value visited
 */
static NominalValueRepresentation[] extractNomValReps(final Set<DataCell> possibleValues) {
    final NominalValueRepresentation[] representations = new NominalValueRepresentation[possibleValues.size()];
    int nextIndex = 0;
    for (final DataCell cell : possibleValues) {
        final int assignedIndex = nextIndex++;
        representations[assignedIndex] = new NominalValueRepresentation(cell.toString(), assignedIndex);
    }
    return representations;
}
Also used : NominalValueRepresentation(org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation) DataCell(org.knime.core.data.DataCell)

Aggregations

NominalValueRepresentation (org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation): 14
ColumnMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships): 6
BigInteger (java.math.BigInteger): 5
NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate): 5
DataCell (org.knime.core.data.DataCell): 5
ArrayList (java.util.ArrayList): 4
BitSet (java.util.BitSet): 3
LinkedHashMap (java.util.LinkedHashMap): 3
PredictorRecord (org.knime.base.node.mine.treeensemble2.data.PredictorRecord): 3
TreeNominalColumnMetaData (org.knime.base.node.mine.treeensemble2.data.TreeNominalColumnMetaData): 3
TreeTargetNominalColumnMetaData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnMetaData): 3
IImpurity (org.knime.base.node.mine.treeensemble2.learner.IImpurity): 3
TreeNodeClassification (org.knime.base.node.mine.treeensemble2.model.TreeNodeClassification): 3
TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration): 3
ScoreDistribution (org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution): 2
FilterColumnRow (org.knime.base.data.filter.column.FilterColumnRow): 2
CombinedAttributeValues (org.knime.base.node.mine.treeensemble2.data.BinaryNominalSplitsPCA.CombinedAttributeValues): 2
TreeEnsembleModel (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModel): 2
TreeEnsembleModelPortObject (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject): 2
TreeModelClassification (org.knime.base.node.mine.treeensemble2.model.TreeModelClassification): 2