Search in sources :

Example 16 with TreeData

use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.

the class TreeLearnerClassification method learnSingleTreeRecursive.

/**
 * Grows a single classification tree by building its root node via {@code buildTreeNode} at depth 0.
 *
 * @param exec execution monitor handed on to {@code buildTreeNode}
 * @param rd random data source; not read by this method body
 * @return the tree model wrapping the freshly built root node
 * @throws CanceledExecutionException declared for the tree construction that receives {@code exec}
 */
private TreeModelClassification learnSingleTreeRecursive(final ExecutionMonitor exec, final RandomData rd) throws CanceledExecutionException {
    final TreeData trainingData = getData();
    final RowSample sampledRows = getRowSampling();
    final TreeEnsembleLearnerConfiguration learnerConfig = getConfig();
    final TreeTargetNominalColumnData nominalTarget = (TreeTargetNominalColumnData) trainingData.getTargetColumn();

    // Root memberships are derived from the row sample; the class priors are computed on exactly those memberships.
    final DataMemberships rootMemberships = new RootDataMemberships(sampledRows, trainingData, getIndexManager());
    final ClassificationPriors rootPriors = nominalTarget.getDistribution(rootMemberships, learnerConfig);

    // Starts out empty and is sized by the number of attributes.
    final BitSet forbiddenColumns = new BitSet(trainingData.getNrAttributes());
    final ColumnSample rootColumns = getColSamplingStrategy().getColumnSampleForTreeNode(TreeNodeSignature.ROOT_SIGNATURE);

    final TreeNodeClassification root = buildTreeNode(exec, 0, rootMemberships, rootColumns,
        TreeNodeSignature.ROOT_SIGNATURE, rootPriors, forbiddenColumns);
    // The set is expected to be empty again once the root node has been built.
    assert forbiddenColumns.cardinality() == 0;
    root.setTreeNodeCondition(TreeNodeTrueCondition.INSTANCE);
    return new TreeModelClassification(root);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeNodeClassification(org.knime.base.node.mine.treeensemble2.model.TreeNodeClassification) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) BitSet(java.util.BitSet) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) RowSample(org.knime.base.node.mine.treeensemble2.sample.row.RowSample) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeTargetNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData) ClassificationPriors(org.knime.base.node.mine.treeensemble2.data.ClassificationPriors) TreeModelClassification(org.knime.base.node.mine.treeensemble2.model.TreeModelClassification)

Example 17 with TreeData

use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.

the class AbstractGradientBoostingLearner method createResidualDataFromArray.

/**
 * Builds a {@link TreeData} object whose target column holds the values of <b>residualData</b>.
 * <p>
 * Row keys and target meta data are copied from the target column of <b>actualData</b>. The attribute columns and
 * the tree type of the result are taken from {@code getData()}.
 *
 * @param residualData array containing the residuals, used as the new target values
 * @param actualData the TreeData as it is provided by the user; its target column is cast to
 *            {@link TreeTargetNumericColumnData}
 * @return data using the residuals as targets
 */
protected TreeData createResidualDataFromArray(final double[] residualData, final TreeData actualData) {
    final TreeTargetNumericColumnData originalTarget = (TreeTargetNumericColumnData) actualData.getTargetColumn();
    final int rowCount = originalTarget.getNrRows();

    // Carry over the row keys one by one so the new target column lines up with the original rows.
    final RowKey[] rowKeys = new RowKey[rowCount];
    for (int row = 0; row < rowCount; row++) {
        rowKeys[row] = originalTarget.getRowKeyFor(row);
    }

    final TreeTargetNumericColumnMetaData targetMetaData = originalTarget.getMetaData();
    final TreeTargetNumericColumnData residualTarget =
        new TreeTargetNumericColumnData(targetMetaData, rowKeys, residualData);
    return new TreeData(getData().getColumns(), residualTarget, getData().getTreeType());
}
Also used : RowKey(org.knime.core.data.RowKey) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) TreeTargetNumericColumnMetaData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnMetaData) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData)

Example 18 with TreeData

use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.

the class MGradientBoostedTreesLearner method learn.

/**
 * {@inheritDoc}
 * <p>
 * Learns {@code config.getNrModels()} regression trees one after another. The prediction starts at the median of
 * the target column. In every round the residuals (actual value minus current prediction) are computed, clipped
 * to the alpha quantile returned by {@code calculateAlphaQuantile}, and a tree is fitted to the clipped values.
 * The per-leaf coefficients from {@code calcCoefficientMap} are then used to update the running prediction.
 *
 * @param exec monitor used for progress messages and passed on to the single-tree learner
 * @return the boosted model consisting of all learned trees, their coefficient maps and the initial value
 * @throws CanceledExecutionException declared for the single-tree learner that receives {@code exec}
 */
@Override
public AbstractGradientBoostingModel learn(final ExecutionMonitor exec) throws CanceledExecutionException {
    final TreeData actualData = getData();
    final GradientBoostingLearnerConfiguration config = getConfig();
    final int nrModels = config.getNrModels();
    final TreeTargetNumericColumnData actualTarget = getTarget();
    final int nrRows = actualTarget.getNrRows();
    final double initialValue = actualTarget.getMedian();
    final ArrayList<TreeModelRegression> models = new ArrayList<TreeModelRegression>(nrModels);
    final ArrayList<Map<TreeNodeSignature, Double>> coefficientMaps =
        new ArrayList<Map<TreeNodeSignature, Double>>(nrModels);
    // Running prediction per row, starting at the median of the target.
    final double[] previousPrediction = new double[nrRows];
    Arrays.fill(previousPrediction, initialValue);
    final RandomData rd = config.createRandomData();
    final double alpha = config.getAlpha();
    final int maxLevels = config.getMaxLevels();
    final TreeNodeSignatureFactory signatureFactory;
    // this should be the default
    if (maxLevels < TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE) {
        // NOTE(review): IntMath.pow does not guard against int overflow - confirm maxLevels is bounded accordingly.
        final int capacity = IntMath.pow(2, maxLevels - 1);
        signatureFactory = new TreeNodeSignatureFactory(capacity);
    } else {
        signatureFactory = new TreeNodeSignatureFactory();
    }
    exec.setMessage("Learning model");
    for (int i = 0; i < nrModels; i++) {
        final double[] residuals = new double[nrRows];
        for (int j = 0; j < nrRows; j++) {
            residuals[j] = actualTarget.getValueFor(j) - previousPrediction[j];
        }
        final double quantile = calculateAlphaQuantile(residuals, alpha);
        // Residuals whose magnitude exceeds the quantile are clipped to +/- quantile.
        final double[] gradients = new double[residuals.length];
        for (int j = 0; j < gradients.length; j++) {
            gradients[j] = Math.abs(residuals[j]) <= quantile ? residuals[j] : quantile * Math.signum(residuals[j]);
        }
        final TreeData residualData = createResidualDataFromArray(gradients, actualData);
        // Each tree gets its own random source, seeded from the shared one.
        final RandomData rdSingle =
            TreeEnsembleLearnerConfiguration.createRandomData(rd.nextLong(Long.MIN_VALUE, Long.MAX_VALUE));
        final RowSample rowSample = getRowSampler().createRowSample(rdSingle);
        final TreeLearnerRegression treeLearner = new TreeLearnerRegression(getConfig(), residualData,
            getIndexManager(), signatureFactory, rdSingle, rowSample);
        final TreeModelRegression tree = treeLearner.learnSingleTree(exec, rdSingle);
        // Coefficients are computed from the unclipped residuals together with the quantile.
        final Map<TreeNodeSignature, Double> coefficientMap = calcCoefficientMap(residuals, quantile, tree);
        adaptPreviousPrediction(previousPrediction, tree, coefficientMap);
        models.add(tree);
        coefficientMaps.add(coefficientMap);
        // i is zero-based: after this iteration i + 1 trees are done, so progress reaches 1.0 on the last one.
        final int finished = i + 1;
        exec.setProgress(((double) finished) / nrModels, "Finished level " + finished + "/" + nrModels);
    }
    return new GradientBoostedTreesModel(getConfig(), actualData.getMetaData(),
        models.toArray(new TreeModelRegression[models.size()]), actualData.getTreeType(), initialValue,
        coefficientMaps);
}
Also used : RandomData(org.apache.commons.math.random.RandomData) ArrayList(java.util.ArrayList) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) GradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) GradientBoostingLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeLearnerRegression(org.knime.base.node.mine.treeensemble2.learner.TreeLearnerRegression) RowSample(org.knime.base.node.mine.treeensemble2.sample.row.RowSample) HashMap(java.util.HashMap) Map(java.util.Map) TreeNodeSignatureFactory(org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory)

Example 19 with TreeData

use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.

the class TreeNumericColumnDataTest method testCalcBestSplitClassification.

/**
 * Verifies the best classification split of a numeric column on a ten-row example, first on the root memberships
 * and then on both children produced by that split.
 *
 * @throws Exception if creating the example data or calculating a split fails
 */
@Test
public void testCalcBestSplitClassification() throws Exception {
    TreeEnsembleLearnerConfiguration config = createConfig();
    /* data from J. Fuernkranz, Uni Darmstadt:
     * http://www.ke.tu-darmstadt.de/lehre/archiv/ws0809/mldm/dt.pdf */
    final double[] data = asDataArray("60,70,75,85, 90, 95, 100,120,125,220");
    final String[] target = asStringArray("No,No,No,Yes,Yes,Yes,No, No, No, No");
    Pair<TreeOrdinaryNumericColumnData, TreeTargetNominalColumnData> exampleData = exampleData(config, data, target);
    RandomData rd = config.createRandomData();
    TreeNumericColumnData columnData = exampleData.getFirst();
    TreeTargetNominalColumnData targetData = exampleData.getSecond();
    // The expected gain values below are only valid for the Gini criterion.
    assertEquals(SplitCriterion.Gini, config.getSplitCriterion());
    double[] rowWeights = new double[data.length];
    Arrays.fill(rowWeights, 1.0);
    TreeData treeData = createTreeDataClassification(exampleData);
    IDataIndexManager indexManager = new DefaultDataIndexManager(treeData);
    DataMemberships dataMemberships = new RootDataMemberships(rowWeights, treeData, indexManager);
    ClassificationPriors priors = targetData.getDistribution(rowWeights, config);

    // Root: expect a binary split between 95 and 100.
    SplitCandidate splitCandidate = columnData.calcBestSplitClassification(dataMemberships, priors, targetData, rd);
    assertNotNull(splitCandidate);
    assertThat(splitCandidate, instanceOf(NumericSplitCandidate.class));
    assertTrue(splitCandidate.canColumnBeSplitFurther());
    // libre office calc
    assertEquals(/*0.42 - 0.300 */
    0.12, splitCandidate.getGainValue(), 0.00001);
    NumericSplitCandidate numSplitCandidate = (NumericSplitCandidate) splitCandidate;
    TreeNodeNumericCondition[] childConditions = numSplitCandidate.getChildConditions();
    assertEquals(2, childConditions.length);
    assertEquals((95.0 + 100.0) / 2.0, childConditions[0].getSplitValue(), 0.0);
    assertEquals((95.0 + 100.0) / 2.0, childConditions[1].getSplitValue(), 0.0);
    assertEquals(NumericOperator.LessThanOrEqual, childConditions[0].getNumericOperator());
    assertEquals(NumericOperator.LargerThan, childConditions[1].getNumericOperator());

    // First child (<= 97.5): rows 60..95 with targets No,No,No,Yes,Yes,Yes - expect a split between 75 and 85.
    BitSet inChild = columnData.updateChildMemberships(childConditions[0], dataMemberships);
    DataMemberships childMemberships = dataMemberships.createChildMemberships(inChild);
    ClassificationPriors childTargetPriors = targetData.getDistribution(childMemberships, config);
    SplitCandidate splitCandidateChild = columnData.calcBestSplitClassification(childMemberships, childTargetPriors, targetData, rd);
    assertNotNull(splitCandidateChild);
    assertThat(splitCandidateChild, instanceOf(NumericSplitCandidate.class));
    // manually via libre office calc
    assertEquals(0.5, splitCandidateChild.getGainValue(), 0.00001);
    TreeNodeNumericCondition[] childConditions2 = ((NumericSplitCandidate) splitCandidateChild).getChildConditions();
    assertEquals(2, childConditions2.length);
    assertEquals((75.0 + 85.0) / 2.0, childConditions2[0].getSplitValue(), 0.0);

    // Second child (> 97.5): all remaining targets are "No", so no split candidate is expected.
    inChild = columnData.updateChildMemberships(childConditions[1], dataMemberships);
    childMemberships = dataMemberships.createChildMemberships(inChild);
    childTargetPriors = targetData.getDistribution(childMemberships, config);
    splitCandidateChild = columnData.calcBestSplitClassification(childMemberships, childTargetPriors, targetData, rd);
    assertNull(splitCandidateChild);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) RandomData(org.apache.commons.math.random.RandomData) TreeNodeNumericCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition) BitSet(java.util.BitSet) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) NumericMissingSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) Test(org.junit.Test)

Example 20 with TreeData

use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.

the class TreeNumericColumnDataTest method testCalcBestSplitClassificationSplitAtEnd.

/**
 * Test splits at last possible split position - even if no change in target can be observed, see example data in
 * method body. After the root split neither child is expected to yield a further split candidate.
 *
 * @throws Exception if creating the example data or calculating a split fails
 */
@Test
public void testCalcBestSplitClassificationSplitAtEnd() throws Exception {
    // Index:  1 2 3 4 5 6 7 8
    // Value:  1 1|2 2 2|3 3 3
    // Target: A A|A A A|A A B
    double[] data = asDataArray("1,1,2,2,2,3,3,3");
    String[] target = asStringArray("A,A,A,A,A,A,A,B");
    TreeEnsembleLearnerConfiguration config = createConfig();
    RandomData rd = config.createRandomData();
    Pair<TreeOrdinaryNumericColumnData, TreeTargetNominalColumnData> exampleData = exampleData(config, data, target);
    TreeNumericColumnData columnData = exampleData.getFirst();
    TreeTargetNominalColumnData targetData = exampleData.getSecond();
    double[] rowWeights = new double[data.length];
    Arrays.fill(rowWeights, 1.0);
    TreeData treeData = createTreeDataClassification(exampleData);
    IDataIndexManager indexManager = new DefaultDataIndexManager(treeData);
    DataMemberships dataMemberships = new RootDataMemberships(rowWeights, treeData, indexManager);
    ClassificationPriors priors = targetData.getDistribution(rowWeights, config);

    // Root: the only useful boundary is the last one, between values 2 and 3.
    SplitCandidate splitCandidate = columnData.calcBestSplitClassification(dataMemberships, priors, targetData, rd);
    assertNotNull(splitCandidate);
    assertThat(splitCandidate, instanceOf(NumericSplitCandidate.class));
    assertTrue(splitCandidate.canColumnBeSplitFurther());
    // manually calculated
    assertEquals(/*0.21875 - 0.166666667 */
    0.05208, splitCandidate.getGainValue(), 0.001);
    NumericSplitCandidate numSplitCandidate = (NumericSplitCandidate) splitCandidate;
    TreeNodeNumericCondition[] childConditions = numSplitCandidate.getChildConditions();
    assertEquals(2, childConditions.length);
    assertEquals((2.0 + 3.0) / 2.0, childConditions[0].getSplitValue(), 0.0);
    assertEquals(NumericOperator.LessThanOrEqual, childConditions[0].getNumericOperator());

    // First child (<= 2.5): every target is "A", so no further split is expected.
    BitSet inChild = columnData.updateChildMemberships(childConditions[0], dataMemberships);
    DataMemberships childMemberships = dataMemberships.createChildMemberships(inChild);
    ClassificationPriors childTargetPriors = targetData.getDistribution(childMemberships, config);
    SplitCandidate splitCandidateChild = columnData.calcBestSplitClassification(childMemberships, childTargetPriors, targetData, rd);
    assertNull(splitCandidateChild);

    // Second child: every remaining row has the value 3, so no further split is expected.
    inChild = columnData.updateChildMemberships(childConditions[1], dataMemberships);
    childMemberships = dataMemberships.createChildMemberships(inChild);
    childTargetPriors = targetData.getDistribution(childMemberships, config);
    // Pass the same random source as the other calls instead of null.
    splitCandidateChild = columnData.calcBestSplitClassification(childMemberships, childTargetPriors, targetData, rd);
    assertNull(splitCandidateChild);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) RandomData(org.apache.commons.math.random.RandomData) TreeNodeNumericCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition) BitSet(java.util.BitSet) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) NumericMissingSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate) DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) Test(org.junit.Test)

Aggregations

TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)27 TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)25 Test (org.junit.Test)18 DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships)18 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)18 DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager)15 IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager)15 RandomData (org.apache.commons.math.random.RandomData)14 TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData)12 SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate)12 BitSet (java.util.BitSet)11 TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData)8 TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)8 ExecutionMonitor (org.knime.core.node.ExecutionMonitor)8 TreeDataCreator (org.knime.base.node.mine.treeensemble2.data.TreeDataCreator)7 TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)7 NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate)7 NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate)7 FilterLearnColumnRearranger (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration.FilterLearnColumnRearranger)7 TreeEnsembleModelPortObjectSpec (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec)6