Search in sources :

Example 11 with DataMemberships

use of org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships in project knime-core by knime.

Class TreeNominalColumnDataTest, method testCalcBestSplitCassificationBinaryTwoClassXGBoostMissingValue.

/**
 * Tests the XGBoost missing value handling in case of a two class problem. <br>
 * Currently not executed (the {@code @Test} annotation is commented out) because missing value handling will
 * probably be implemented differently.
 * <p>
 * The first part searches the best split on the two class tennis data and checks the resulting binary conditions,
 * including that missing values are sent to the first child. The second part replaces the attribute column with one
 * containing missing values ({@code "?"}) and checks the gain, the condition values, the set logic and that missing
 * values are sent to the second child.
 *
 * @throws Exception propagated from test data creation or the split search
 */
// @Test
public void testCalcBestSplitCassificationBinaryTwoClassXGBoostMissingValue() throws Exception {
    final TreeEnsembleLearnerConfiguration config = createConfig(false);
    config.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TestDataGenerator dataGen = new TestDataGenerator(config);
    // check correct behavior if no missing values are encountered during split search
    Pair<TreeNominalColumnData, TreeTargetNominalColumnData> twoClassTennisData = twoClassTennisData(config);
    TreeData treeData = dataGen.createTreeData(twoClassTennisData.getSecond(), twoClassTennisData.getFirst());
    IDataIndexManager indexManager = new DefaultDataIndexManager(treeData);
    // every row participates with weight 1
    double[] rowWeights = new double[TWO_CLASS_INDICES.length];
    Arrays.fill(rowWeights, 1.0);
    DataMemberships dataMemberships = new RootDataMemberships(rowWeights, treeData, indexManager);
    TreeTargetNominalColumnData targetData = twoClassTennisData.getSecond();
    TreeNominalColumnData columnData = twoClassTennisData.getFirst();
    ClassificationPriors priors = targetData.getDistribution(rowWeights, config);
    RandomData rd = TestDataGenerator.createRandomData();
    SplitCandidate splitCandidate = columnData.calcBestSplitClassification(dataMemberships, priors, targetData, rd);
    assertNotNull(splitCandidate);
    assertThat(splitCandidate, instanceOf(NominalBinarySplitCandidate.class));
    NominalBinarySplitCandidate binarySplitCandidate = (NominalBinarySplitCandidate) splitCandidate;
    TreeNodeNominalBinaryCondition[] childConditions = binarySplitCandidate.getChildConditions();
    assertEquals(2, childConditions.length);
    assertArrayEquals(new String[] { "R" }, childConditions[0].getValues());
    assertArrayEquals(new String[] { "R" }, childConditions[1].getValues());
    assertEquals(SetLogic.IS_NOT_IN, childConditions[0].getSetLogic());
    assertEquals(SetLogic.IS_IN, childConditions[1].getSetLogic());
    // check if missing values go left
    assertTrue(childConditions[0].acceptsMissings());
    assertFalse(childConditions[1].acceptsMissings());
    // check correct behavior if missing values are encountered during split search
    String dataContainingMissingsCSV = "S,?,O,R,S,R,S,O,O,?";
    columnData = dataGen.createNominalAttributeColumn(dataContainingMissingsCSV, "column containing missing values", 0);
    // the target column and row weights are unchanged, so the priors computed above are reused
    treeData = dataGen.createTreeData(targetData, columnData);
    indexManager = new DefaultDataIndexManager(treeData);
    dataMemberships = new RootDataMemberships(rowWeights, treeData, indexManager);
    // NOTE(review): passes null RandomData whereas the first search passes rd -- confirm this is intentional
    splitCandidate = columnData.calcBestSplitClassification(dataMemberships, priors, targetData, null);
    assertNotNull(splitCandidate);
    // verify the type before casting so a wrong candidate type fails with a readable assertion message
    assertThat(splitCandidate, instanceOf(NominalBinarySplitCandidate.class));
    binarySplitCandidate = (NominalBinarySplitCandidate) splitCandidate;
    assertEquals("Gain was not as expected", 0.08, binarySplitCandidate.getGainValue(), 1e-8);
    childConditions = binarySplitCandidate.getChildConditions();
    String[] conditionValues = new String[] { "O", "?" };
    assertArrayEquals("Values in nominal condition did not match", conditionValues, childConditions[0].getValues());
    assertArrayEquals("Values in nominal condition did not match", conditionValues, childConditions[1].getValues());
    assertEquals("Wrong set logic.", SetLogic.IS_NOT_IN, childConditions[0].getSetLogic());
    assertEquals("Wrong set logic.", SetLogic.IS_IN, childConditions[1].getSetLogic());
    assertFalse("Missing values are not sent to the correct child.", childConditions[0].acceptsMissings());
    assertTrue("Missing values are not sent to the correct child.", childConditions[1].acceptsMissings());
}
Also used: TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) RandomData(org.apache.commons.math.random.RandomData) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) TreeNodeNominalBinaryCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition)

Example 12 with DataMemberships

use of org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships in project knime-core by knime.

Class TreeBitVectorColumnData, method calcBestSplitRegression.

/**
 * Evaluates the single binary split offered by this bit vector column (bit set vs. bit not set) for a numeric
 * regression target.
 * <p>
 * The score of a node is {@code ySum * ySum / weight}. The gain is the sum of both child scores minus the score of
 * the parent node, which is derived from the regression priors. Rows whose weight is below {@code EPSILON} do not
 * contribute.
 *
 * @return a {@link BitSplitCandidate} if both children reach the configured minimum child size and the gain is
 *         strictly positive, {@code null} otherwise
 */
@Override
public SplitCandidate calcBestSplitRegression(final DataMemberships dataMemberships, final RegressionPriors targetPriors, final TreeTargetNumericColumnData targetColumn, final RandomData rd) {
    final double parentYSum = targetPriors.getYSum();
    final double parentWeight = targetPriors.getNrRecords();
    final double parentScore = parentYSum * parentYSum / parentWeight;
    final int minChildSize = getConfiguration().getMinChildSize();
    final ColumnMemberships memberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());

    double weightBitSet = 0.0;
    double weightBitCleared = 0.0;
    double ySumBitSet = 0.0;
    double ySumBitCleared = 0.0;
    while (memberships.next()) {
        final double rowWeight = memberships.getRowWeight();
        if (rowWeight < EPSILON) {
            // skip rows that are not in the current branch or not part of the sample
            continue;
        }
        final double targetValue = targetColumn.getValueFor(memberships.getOriginalIndex());
        final boolean bitIsSet = m_columnBitSet.get(memberships.getIndexInColumn());
        if (bitIsSet) {
            weightBitSet += rowWeight;
            ySumBitSet += rowWeight * targetValue;
        } else {
            weightBitCleared += rowWeight;
            ySumBitCleared += rowWeight * targetValue;
        }
    }

    final boolean childTooSmall = weightBitSet < minChildSize || weightBitCleared < minChildSize;
    if (childTooSmall) {
        return null;
    }
    final double scoreBitSet = ySumBitSet * ySumBitSet / weightBitSet;
    final double scoreBitCleared = ySumBitCleared * ySumBitCleared / weightBitCleared;
    final double gain = scoreBitSet + scoreBitCleared - parentScore;
    return gain > 0 ? new BitSplitCandidate(this, gain) : null;
}
Also used : BitSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.BitSplitCandidate) ColumnMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships)

Example 13 with DataMemberships

use of org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships in project knime-core by knime.

Class TreeBitVectorColumnData, method updateChildMemberships.

/**
 * Determines which rows of the parent node belong to the child described by {@code childCondition}.
 * <p>
 * A row belongs to the child if its bit in this column equals {@link TreeNodeBitCondition#getValue()} of the
 * condition.
 *
 * @param childCondition the condition of the child node; expected to be a {@link TreeNodeBitCondition} on this
 *            column (checked by an assertion on the attribute name)
 * @param parentMemberships the memberships of the parent node
 * @return a {@link BitSet} with a set bit at the data memberships index of every row that belongs to the child;
 *         empty if the parent memberships contain no rows
 */
@Override
public BitSet updateChildMemberships(final TreeNodeCondition childCondition, final DataMemberships parentMemberships) {
    final TreeNodeBitCondition bitCondition = (TreeNodeBitCondition) childCondition;
    assert getMetaData().getAttributeName().equals(bitCondition.getColumnMetaData().getAttributeName());
    final boolean value = bitCondition.getValue();
    final ColumnMemberships columnMemberships = parentMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
    final BitSet inChild = new BitSet(columnMemberships.size());
    columnMemberships.reset();
    // Advance before reading: the previous version read getIndexInColumn() even when the very first next()
    // returned false, i.e. for an empty parent it read a membership that does not exist.
    while (columnMemberships.next()) {
        if (m_columnBitSet.get(columnMemberships.getIndexInColumn()) == value) {
            inChild.set(columnMemberships.getIndexInDataMemberships());
        }
    }
    return inChild;
}
Also used : BitSet(java.util.BitSet) ColumnMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships) TreeNodeBitCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeBitCondition)

Example 14 with DataMemberships

use of org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships in project knime-core by knime.

Class TreeBitVectorColumnData, method calcBestSplitClassification.

/**
 * Evaluates the binary split of this bit vector column (bit set vs. bit not set) for a nominal target.
 * <p>
 * The weighted class distribution is accumulated separately for both children. The impurity criterion of the
 * priors then computes the partition impurities, the post-split impurity and finally the gain relative to the prior
 * impurity.
 *
 * @return a {@link BitSplitCandidate} carrying the computed gain, or {@code null} if one of the children has less
 *         weight than the configured minimum child size
 */
@Override
public SplitCandidate calcBestSplitClassification(final DataMemberships dataMemberships, final ClassificationPriors targetPriors, final TreeTargetNominalColumnData targetColumn, final RandomData rd) {
    final int classCount = targetColumn.getMetaData().getValues().length;
    final IImpurity impurity = targetPriors.getImpurityCriterion();
    final int minChildSize = getConfiguration().getMinChildSize();
    // weighted class distribution of the rows whose bit is set ('1') and not set ('0')
    final double[] classWeightsBitSet = new double[classCount];
    final double[] classWeightsBitCleared = new double[classCount];
    double totalBitSet = 0.0;
    double totalBitCleared = 0.0;
    final ColumnMemberships memberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
    while (memberships.next()) {
        final double rowWeight = memberships.getRowWeight();
        if (rowWeight < EPSILON) {
            // zero-weight rows (not in current branch or not in sample) are not expected here;
            // with assertions disabled they are simply skipped
            assert false : "This code should never be reached!";
            continue;
        }
        final int classIndex = targetColumn.getValueFor(memberships.getOriginalIndex());
        if (!m_columnBitSet.get(memberships.getIndexInColumn())) {
            totalBitCleared += rowWeight;
            classWeightsBitCleared[classIndex] += rowWeight;
        } else {
            totalBitSet += rowWeight;
            classWeightsBitSet[classIndex] += rowWeight;
        }
    }
    if (totalBitSet < minChildSize || totalBitCleared < minChildSize) {
        return null;
    }
    final double totalWeight = totalBitSet + totalBitCleared;
    final double impurityBitSet = impurity.getPartitionImpurity(classWeightsBitSet, totalBitSet);
    final double impurityBitCleared = impurity.getPartitionImpurity(classWeightsBitCleared, totalBitCleared);
    final double[] childWeights = { totalBitSet, totalBitCleared };
    final double[] childImpurities = { impurityBitSet, impurityBitCleared };
    final double impurityAfterSplit = impurity.getPostSplitImpurity(childImpurities, childWeights, totalWeight);
    final double gain = impurity.getGain(targetPriors.getPriorImpurity(), impurityAfterSplit, childWeights, totalWeight);
    return new BitSplitCandidate(this, gain);
}
Also used : BitSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.BitSplitCandidate) ColumnMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships) IImpurity(org.knime.base.node.mine.treeensemble2.learner.IImpurity)

Example 15 with DataMemberships

use of org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships in project knime-core by knime.

Class TreeNumericColumnData, method updateChildMemberships.

/**
 * Determines which rows of the parent node belong to the child described by {@code childCondition}.
 * <p>
 * The column memberships are traversed in column order. Rows with a column index below
 * {@link #getLengthNonMissing()} are tested against the split value with the condition's numeric operator. All
 * remaining rows are treated as missing values. They are sent to the child as a block if the operator is
 * {@code LessThanOrEqualOrMissing} or {@code LargerThanOrMissing}, or if the condition accepts missing values.
 *
 * @param childCondition the condition of the child node; expected to be a {@link TreeNodeNumericCondition}
 * @param parentMemberships the memberships of the parent node
 * @return a {@link BitSet} with a set bit at the data memberships index of every row that belongs to the child
 * @throws IllegalStateException if the parent memberships contain no row, or if an unsupported numeric operator
 *             is encountered while testing a non-missing value
 */
@Override
public BitSet updateChildMemberships(final TreeNodeCondition childCondition, final DataMemberships parentMemberships) {
    final TreeNodeNumericCondition numCondition = (TreeNodeNumericCondition) childCondition;
    final NumericOperator numOperator = numCondition.getNumericOperator();
    final double splitValue = numCondition.getSplitValue();
    final ColumnMemberships columnMemberships = parentMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
    columnMemberships.reset();
    final BitSet inChild = new BitSet(columnMemberships.size());
    if (!columnMemberships.nextIndexFrom(0)) {
        throw new IllegalStateException("The current columnMemberships object contains no element that satisfies the splitcondition");
    }
    final int lengthNonMissing = getLengthNonMissing();
    // Non-missing part: test each value with the operator. Checking the index before reading ensures that a
    // missing value at the very first position is handled by the missing-value block below and not silently
    // dropped by a NaN comparison.
    while (columnMemberships.getIndexInColumn() < lengthNonMissing) {
        final double value = getSorted(columnMemberships.getIndexInColumn());
        final boolean matches;
        switch (numOperator) {
            case LessThanOrEqual:
                matches = value <= splitValue;
                break;
            case LargerThan:
                matches = value > splitValue;
                break;
            case LessThanOrEqualOrMissing:
                matches = Double.isNaN(value) || value <= splitValue;
                break;
            case LargerThanOrMissing:
                matches = Double.isNaN(value) || value > splitValue;
                break;
            default:
                throw new IllegalStateException("Unknown operator " + numOperator);
        }
        if (matches) {
            inChild.set(columnMemberships.getIndexInDataMemberships());
        }
        if (!columnMemberships.next()) {
            // reached end of columnMemberships without encountering a missing value
            return inChild;
        }
    }
    // The current membership is the first one at or beyond lengthNonMissing; it and all following ones are
    // treated as missing values.
    final boolean missingsGoToChild = numOperator == NumericOperator.LessThanOrEqualOrMissing
        || numOperator == NumericOperator.LargerThanOrMissing || numCondition.acceptsMissings();
    if (missingsGoToChild) {
        do {
            inChild.set(columnMemberships.getIndexInDataMemberships());
        } while (columnMemberships.next());
    }
    return inChild;
}
Also used : TreeNodeNumericCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition) BitSet(java.util.BitSet) ColumnMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships) NumericOperator(org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition.NumericOperator)

Aggregations

TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)34 DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships)26 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)25 BitSet (java.util.BitSet)21 Test (org.junit.Test)21 SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate)17 RandomData (org.apache.commons.math.random.RandomData)15 DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager)14 NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate)13 IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager)12 NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate)12 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)10 ColumnMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships)10 TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData)9 TreeNodeNominalBinaryCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition)9 NumericSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate)7 TreeNodeNumericCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition)7 TreeNodeCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)6 TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)5 NumericMissingSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate)5