Search in sources :

Example 6 with RegressionPriors

use of org.knime.base.node.mine.treeensemble2.data.RegressionPriors in project knime-core by knime.

In the class TreeBitVectorColumnData, the method calcBestSplitRegression.

/**
 * {@inheritDoc}
 * <p>
 * A bit column offers exactly one split: records whose bit is set versus records whose bit is not set.
 * The weighted target sums of both sides are accumulated in a single pass over the column memberships;
 * the split is returned only if both sides reach the configured minimum child size and the gain is
 * strictly positive.
 *
 * @return a {@link BitSplitCandidate} carrying the gain, or {@code null} if no useful split exists
 */
@Override
public SplitCandidate calcBestSplitRegression(final DataMemberships dataMemberships, final RegressionPriors targetPriors, final TreeTargetNumericColumnData targetColumn, final RandomData rd) {
    final double totalTargetSum = targetPriors.getYSum();
    final double totalWeight = targetPriors.getNrRecords();
    final int minChildSize = getConfiguration().getMinChildSize();
    final ColumnMemberships memberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());

    double setWeight = 0.0;
    double unsetWeight = 0.0;
    double setTargetSum = 0.0;
    double unsetTargetSum = 0.0;
    while (memberships.next()) {
        final double rowWeight = memberships.getRowWeight();
        if (rowWeight < EPSILON) {
            // ignore record: not in current branch or not in sample
            continue;
        }
        final double target = targetColumn.getValueFor(memberships.getOriginalIndex());
        final boolean bitIsSet = m_columnBitSet.get(memberships.getIndexInColumn());
        if (bitIsSet) {
            setWeight += rowWeight;
            setTargetSum += rowWeight * target;
        } else {
            unsetWeight += rowWeight;
            unsetTargetSum += rowWeight * target;
        }
    }

    // both children must be large enough for the split to be admissible
    if (setWeight < minChildSize || unsetWeight < minChildSize) {
        return null;
    }

    // gain = (sum of squared target sums over weight, per child) - (the same term for the parent)
    final double parentCriterion = totalTargetSum * totalTargetSum / totalWeight;
    final double setCriterion = setTargetSum * setTargetSum / setWeight;
    final double unsetCriterion = unsetTargetSum * unsetTargetSum / unsetWeight;
    final double gain = setCriterion + unsetCriterion - parentCriterion;
    return gain > 0 ? new BitSplitCandidate(this, gain) : null;
}
Also used : BitSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.BitSplitCandidate) ColumnMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships)

Example 7 with RegressionPriors

use of org.knime.base.node.mine.treeensemble2.data.RegressionPriors in project knime-core by knime.

In the class TreeTargetNumericColumnDataTest, the method testGetPriors.

/**
 * Tests the {@link TreeTargetNumericColumnData#getPriors(DataMemberships, TreeEnsembleLearnerConfiguration)} and
 * {@link TreeTargetNumericColumnData#getPriors(double[], TreeEnsembleLearnerConfiguration)} methods.
 * <p>
 * Both overloads receive the same ten rows with unit weights, so both must report the same statistics:
 * sum 69, mean 6.9, and sum of squared deviations 829 - 69^2 / 10 = 352.9.
 */
@Test
public void testGetPriors() {
    String targetCSV = "1,4,3,5,6,7,8,12,22,1";
    // irrelevant but necessary to build TreeDataObject
    String someAttributeCSV = "A,B,A,B,A,A,B,A,A,B";
    TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(true);
    TestDataGenerator dataGen = new TestDataGenerator(config);
    TreeTargetNumericColumnData target = TestDataGenerator.createNumericTargetColumn(targetCSV);
    TreeNominalColumnData attribute = dataGen.createNominalAttributeColumn(someAttributeCSV, "test-col", 0);
    TreeData data = new TreeData(new TreeAttributeColumnData[] { attribute }, target, TreeType.Ordinary);
    double[] weights = new double[10];
    Arrays.fill(weights, 1.0);
    DataMemberships rootMem = new RootDataMemberships(weights, data, new DefaultDataIndexManager(data));

    // overload taking the data memberships of the root node
    RegressionPriors datMemPriors = target.getPriors(rootMem, config);
    assertEquals(10, datMemPriors.getNrRecords(), DELTA);
    assertEquals(6.9, datMemPriors.getMean(), DELTA);
    assertEquals(69, datMemPriors.getYSum(), DELTA);
    assertEquals(352.9, datMemPriors.getSumSquaredDeviation(), DELTA);

    // overload taking the raw row weights; previously documented as tested but never exercised
    RegressionPriors weightPriors = target.getPriors(weights, config);
    assertEquals(10, weightPriors.getNrRecords(), DELTA);
    assertEquals(6.9, weightPriors.getMean(), DELTA);
    assertEquals(69, weightPriors.getYSum(), DELTA);
    assertEquals(352.9, weightPriors.getSumSquaredDeviation(), DELTA);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) Test(org.junit.Test)

Example 8 with RegressionPriors

use of org.knime.base.node.mine.treeensemble2.data.RegressionPriors in project knime-core by knime.

In the class TreeNominalColumnDataTest, the method testCalcBestSplitRegressionMultiway.

/**
 * Verifies
 * {@link TreeNominalColumnData#calcBestSplitRegression(DataMemberships, RegressionPriors, TreeTargetNumericColumnData, RandomData)}
 * when binary nominal splits are disabled, i.e. a multiway split is expected.
 *
 * @throws Exception if the test fixture cannot be created
 */
@Test
public void testCalcBestSplitRegressionMultiway() throws Exception {
    // fixture: tennis regression data with multiway nominal splits enabled
    final TreeEnsembleLearnerConfiguration learnerConfig = createConfig(true);
    learnerConfig.setUseBinaryNominalSplits(false);
    final Pair<TreeNominalColumnData, TreeTargetNumericColumnData> columns = tennisDataRegression(learnerConfig);
    final TreeNominalColumnData nominalColumn = columns.getFirst();
    final TreeTargetNumericColumnData target = columns.getSecond();
    final TreeData data = createTreeDataRegression(columns);

    // every row takes part with unit weight
    final double[] unitWeights = new double[SMALL_COLUMN_DATA.length];
    Arrays.fill(unitWeights, 1.0);
    final IDataIndexManager dataIndexManager = new DefaultDataIndexManager(data);
    final DataMemberships rootMemberships = new RootDataMemberships(unitWeights, data, dataIndexManager);
    final RegressionPriors rootPriors = target.getPriors(unitWeights, learnerConfig);

    final SplitCandidate candidate = nominalColumn.calcBestSplitRegression(rootMemberships, rootPriors, target, null);

    assertNotNull(candidate);
    assertThat(candidate, instanceOf(NominalMultiwaySplitCandidate.class));
    assertFalse(candidate.canColumnBeSplitFurther());
    assertEquals(36.9643, candidate.getGainValue(), 0.0001);

    // one child condition per nominal value, in this order
    final TreeNodeNominalCondition[] conditions = ((NominalMultiwaySplitCandidate) candidate).getChildConditions();
    final String[] expectedValues = { "S", "O", "R" };
    assertEquals(expectedValues.length, conditions.length);
    for (int i = 0; i < expectedValues.length; i++) {
        assertEquals(expectedValues[i], conditions[i].getValue());
    }
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeNodeNominalCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) Test(org.junit.Test)

Example 9 with RegressionPriors

use of org.knime.base.node.mine.treeensemble2.data.RegressionPriors in project knime-core by knime.

In the class TreeNominalColumnDataTest, the method testCalcBestSplitRegressionBinary.

/**
 * Verifies
 * {@link TreeNominalColumnData#calcBestSplitRegression(DataMemberships, RegressionPriors, TreeTargetNumericColumnData, RandomData)}
 * for the binary nominal split case.
 *
 * @throws Exception if the test fixture cannot be created
 */
@Test
public void testCalcBestSplitRegressionBinary() throws Exception {
    // fixture: tennis regression data with a freshly constructed learner configuration
    final TreeEnsembleLearnerConfiguration learnerConfig = new TreeEnsembleLearnerConfiguration(true);
    final Pair<TreeNominalColumnData, TreeTargetNumericColumnData> columns = tennisDataRegression(learnerConfig);
    final TreeNominalColumnData nominalColumn = columns.getFirst();
    final TreeTargetNumericColumnData target = columns.getSecond();
    final TreeData data = createTreeDataRegression(columns);

    // every row takes part with unit weight
    final double[] unitWeights = new double[SMALL_COLUMN_DATA.length];
    Arrays.fill(unitWeights, 1.0);
    final IDataIndexManager dataIndexManager = new DefaultDataIndexManager(data);
    final DataMemberships rootMemberships = new RootDataMemberships(unitWeights, data, dataIndexManager);
    final RegressionPriors rootPriors = target.getPriors(unitWeights, learnerConfig);

    final SplitCandidate candidate = nominalColumn.calcBestSplitRegression(rootMemberships, rootPriors, target, null);

    assertNotNull(candidate);
    assertThat(candidate, instanceOf(NominalBinarySplitCandidate.class));
    assertTrue(candidate.canColumnBeSplitFurther());
    assertEquals(32.9143, candidate.getGainValue(), 0.0001);

    // two complementary conditions on the value set {"R"}: first "not in", then "in"
    final TreeNodeNominalBinaryCondition[] conditions = ((NominalBinarySplitCandidate) candidate).getChildConditions();
    assertEquals(2, conditions.length);
    final TreeNodeNominalBinaryCondition first = conditions[0];
    final TreeNodeNominalBinaryCondition second = conditions[1];
    assertArrayEquals(new String[] { "R" }, first.getValues());
    assertArrayEquals(new String[] { "R" }, second.getValues());
    assertEquals(SetLogic.IS_NOT_IN, first.getSetLogic());
    assertEquals(SetLogic.IS_IN, second.getSetLogic());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeNodeNominalBinaryCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) Test(org.junit.Test)

Example 10 with RegressionPriors

use of org.knime.base.node.mine.treeensemble2.data.RegressionPriors in project knime-core by knime.

In the class TreeNumericColumnData, the method calcBestSplitRegression.

/**
 * Searches this numeric column for the split point that maximizes the regression gain, i.e. the sum over
 * both children of (target sum)^2 / weight minus the same term for the parent.
 * <p>
 * Records are visited through the column memberships; the value of each record is read via
 * {@code getSorted(indexInColumn)}. Records whose index in the column is {@code >= getLengthNonMissing()}
 * are treated as missing values. Their weight and weighted target sum are collected up front and either
 * removed from the totals (default handling) or tentatively added to the left and to the right child
 * (XGBoost handling), keeping whichever direction yields the larger criterion.
 *
 * @param dataMemberships memberships of the current branch
 * @param targetPriors weighted target sum and total weight of the current branch
 * @param targetColumn the numeric target column
 * @param rd source of randomness, used to break ties between equally good splits
 * @return the best split candidate, or {@code null} if no split with positive gain exists or all records
 *         in the branch are missing
 * @throws UnsupportedOperationException if a record in the branch has a non-integer weight
 */
@Override
public SplitCandidate calcBestSplitRegression(final DataMemberships dataMemberships, final RegressionPriors targetPriors, final TreeTargetNumericColumnData targetColumn, final RandomData rd) {
    final TreeEnsembleLearnerConfiguration config = getConfiguration();
    final boolean useAverageSplitPoints = config.isUseAverageSplitPoints();
    final int minChildNodeSize = config.getMinChildSize();
    // get columnMemberships
    final ColumnMemberships columnMemberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
    final int lengthNonMissing = getLengthNonMissing();
    // missing value handling
    final boolean useXGBoostMissingValueHandling = config.getMissingValueHandling() == MissingValueHandling.XGBoost;
    // are there missing values in this column (complete column)
    boolean branchContainsMissingValues = containsMissingValues();
    boolean missingsGoLeft = true;
    double missingWeight = 0.0;
    double missingY = 0.0;
    // check if there are missing values in this rowsample
    if (branchContainsMissingValues) {
        columnMemberships.goToLast();
        while (columnMemberships.getIndexInColumn() >= lengthNonMissing) {
            final double weight = columnMemberships.getRowWeight();
            missingWeight += weight;
            // The target sum must be weighted like every other target sum in this method (and like the
            // priors' y-sum it is subtracted from); otherwise rows with weight != 1 skew the totals.
            missingY += weight * targetColumn.getValueFor(columnMemberships.getOriginalIndex());
            if (!columnMemberships.previous()) {
                break;
            }
        }
        columnMemberships.reset();
        branchContainsMissingValues = missingWeight > 0.0;
    }
    // totals over the non-missing records only
    final double ySumTotal = targetPriors.getYSum() - missingY;
    final double nrRecordsTotal = targetPriors.getNrRecords() - missingWeight;
    // with XGBoost handling the missing records stay in the parent, so they are added back here
    final double criterionTotal = useXGBoostMissingValueHandling ? (ySumTotal + missingY) * (ySumTotal + missingY) / (nrRecordsTotal + missingWeight) : ySumTotal * ySumTotal / nrRecordsTotal;
    double ySumLeft = 0.0;
    double nrRecordsLeft = 0.0;
    double ySumRight = ySumTotal;
    double nrRecordsRight = nrRecordsTotal;
    // all values in the current branch are missing
    if (nrRecordsRight == 0) {
        // it is impossible to determine a split
        return null;
    }
    double bestSplit = Double.NEGATIVE_INFINITY;
    double bestImprovement = 0.0;
    double lastSeenY = Double.NaN;
    double lastSeenValue = Double.NEGATIVE_INFINITY;
    double lastSeenWeight = -1.0;
    // compute the gain, keep the one that maximizes the split
    while (columnMemberships.next()) {
        final double weight = columnMemberships.getRowWeight();
        if (weight < EPSILON) {
            // ignore record: not in current branch or not in sample
            continue;
        } else if (Math.floor(weight) != weight) {
            throw new UnsupportedOperationException("weighted records (missing values?) not supported, " + "weight is " + weight);
        }
        final double value = getSorted(columnMemberships.getIndexInColumn());
        if (lastSeenWeight > 0.0) {
            // move the previously seen record from the right child to the left child
            ySumLeft += lastSeenWeight * lastSeenY;
            ySumRight -= lastSeenWeight * lastSeenY;
            nrRecordsLeft += lastSeenWeight;
            nrRecordsRight -= lastSeenWeight;
            // a split is only possible between two distinct values and if both children are large enough
            if (nrRecordsLeft >= minChildNodeSize && nrRecordsRight >= minChildNodeSize && lastSeenValue < value) {
                boolean tempMissingsGoLeft = true;
                double childrenSquaredSum;
                if (branchContainsMissingValues && useXGBoostMissingValueHandling) {
                    // try both directions for the missing values and keep the better one (left wins ties)
                    final double missingsLeftSquaredSum = ((ySumLeft + missingY) * (ySumLeft + missingY) / (nrRecordsLeft + missingWeight)) + (ySumRight * ySumRight / nrRecordsRight);
                    final double missingsRightSquaredSum = (ySumLeft * ySumLeft / nrRecordsLeft) + ((ySumRight + missingY) * (ySumRight + missingY) / (nrRecordsRight + missingWeight));
                    if (missingsLeftSquaredSum >= missingsRightSquaredSum) {
                        childrenSquaredSum = missingsLeftSquaredSum;
                        tempMissingsGoLeft = true;
                    } else {
                        childrenSquaredSum = missingsRightSquaredSum;
                        tempMissingsGoLeft = false;
                    }
                } else {
                    childrenSquaredSum = (ySumLeft * ySumLeft / nrRecordsLeft) + (ySumRight * ySumRight / nrRecordsRight);
                }
                double criterion = childrenSquaredSum - criterionTotal;
                // the random source is consulted only on an exact tie with the best gain so far
                boolean randomTieBreaker = criterion == bestImprovement && rd.nextInt(0, 1) == 1;
                if (criterion > bestImprovement || randomTieBreaker) {
                    bestImprovement = criterion;
                    bestSplit = useAverageSplitPoints ? getCenter(lastSeenValue, value) : lastSeenValue;
                    // if there are no missing values go with majority
                    missingsGoLeft = branchContainsMissingValues ? tempMissingsGoLeft : nrRecordsLeft >= nrRecordsRight;
                }
            }
        }
        lastSeenY = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
        lastSeenValue = value;
        lastSeenWeight = weight;
    }
    if (bestImprovement > 0.0) {
        if (useXGBoostMissingValueHandling) {
            // missing values are routed by the learned direction, so no rows are recorded as missed
            return new NumericSplitCandidate(this, bestSplit, bestImprovement, new BitSet(), missingsGoLeft ? NumericSplitCandidate.MISSINGS_GO_LEFT : NumericSplitCandidate.MISSINGS_GO_RIGHT);
        }
        return new NumericSplitCandidate(this, bestSplit, bestImprovement, getMissedRows(columnMemberships), NumericSplitCandidate.NO_MISSINGS);
    } else {
        return null;
    }
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) BitSet(java.util.BitSet) ColumnMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships)

Aggregations

TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)8 DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships)5 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)5 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)4 TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData)4 NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate)4 BitSet (java.util.BitSet)3 RandomData (org.apache.commons.math.random.RandomData)3 Test (org.junit.Test)3 TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData)3 ColumnMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships)3 DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager)3 IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager)3 BigInteger (java.math.BigInteger)2 RegressionPriors (org.knime.base.node.mine.treeensemble2.data.RegressionPriors)2 NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate)2 SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate)2 TreeNodeRegression (org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression)2 TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)2 GradientBoostingLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration)2