
Example 1 with RowSample

Use of org.knime.base.node.mine.treeensemble2.sample.row.RowSample in project knime-core by knime.

Class TreeLearnerRegression, method learnSingleTree.

/**
 * {@inheritDoc}
 */
@Override
public TreeModelRegression learnSingleTree(final ExecutionMonitor exec, final RandomData rd) throws CanceledExecutionException {
    final TreeTargetNumericColumnData targetColumn = getTargetData();
    final TreeData data = getData();
    final RowSample rowSampling = getRowSampling();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    final IDataIndexManager indexManager = getIndexManager();
    DataMemberships rootDataMemberships = new RootDataMemberships(rowSampling, data, indexManager);
    RegressionPriors targetPriors = targetColumn.getPriors(rootDataMemberships, config);
    BitSet forbiddenColumnSet = new BitSet(data.getNrAttributes());
    boolean isGradientBoosting = config instanceof GradientBoostingLearnerConfiguration;
    if (isGradientBoosting) {
        m_leafs = new ArrayList<TreeNodeRegression>();
    }
    final TreeNodeSignature rootSignature = TreeNodeSignature.ROOT_SIGNATURE;
    final ColumnSample rootColumnSample = getColSamplingStrategy().getColumnSampleForTreeNode(rootSignature);
    TreeNodeRegression rootNode = buildTreeNode(exec, 0, rootDataMemberships, rootColumnSample, getSignatureFactory().getRootSignature(), targetPriors, forbiddenColumnSet);
    assert forbiddenColumnSet.cardinality() == 0;
    rootNode.setTreeNodeCondition(TreeNodeTrueCondition.INSTANCE);
    if (isGradientBoosting) {
        return new TreeModelRegression(rootNode, m_leafs);
    }
    return new TreeModelRegression(rootNode);
}
Also used: TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration), RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships), ColumnSample (org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample), RegressionPriors (org.knime.base.node.mine.treeensemble2.data.RegressionPriors), BitSet (java.util.BitSet), TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData), IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager), TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature), TreeNodeRegression (org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression), DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships), TreeModelRegression (org.knime.base.node.mine.treeensemble2.model.TreeModelRegression), GradientBoostingLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration), TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData), RowSample (org.knime.base.node.mine.treeensemble2.sample.row.RowSample)
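
In learnSingleTree the RowSample is handed to RootDataMemberships, so it decides which rows take part in the root node. To make that concrete, here is a minimal sketch assuming a bootstrap strategy (sampling with replacement) in which each row carries an inclusion count. The class BootstrapRowSampleSketch and its methods are hypothetical illustrations, not the KNIME RowSample API; the seed and sizes are arbitrary.

```java
import java.util.Random;
import java.util.stream.IntStream;

/**
 * Hypothetical stand-in for a bootstrap row sample (NOT the KNIME class):
 * counts[i] records how many times row i was drawn.
 */
public class BootstrapRowSampleSketch {

    private final int[] counts;

    BootstrapRowSampleSketch(int nrRows, int sampleSize, long seed) {
        counts = new int[nrRows];
        Random random = new Random(seed);
        // Draw sampleSize rows uniformly with replacement
        for (int k = 0; k < sampleSize; k++) {
            counts[random.nextInt(nrRows)]++;
        }
    }

    /** How often the given row was drawn (0 means out-of-bag). */
    int getCountFor(int row) {
        return counts[row];
    }

    int getNrRows() {
        return counts.length;
    }

    /** Rows that reach the root node in this sketch: those drawn at least once. */
    int[] rootMembers() {
        return IntStream.range(0, counts.length).filter(i -> counts[i] > 0).toArray();
    }

    public static void main(String[] args) {
        BootstrapRowSampleSketch sample = new BootstrapRowSampleSketch(14, 14, 42L);
        int total = IntStream.range(0, sample.getNrRows()).map(sample::getCountFor).sum();
        System.out.println("total draws = " + total);
        System.out.println("distinct rows in root = " + sample.rootMembers().length);
    }
}
```

With sampleSize equal to nrRows, roughly 1 - 1/e (about 63%) of the rows are expected to be drawn at least once; in this sketch the remaining out-of-bag rows never reach the root node.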

Example 2 with RowSample

Use of org.knime.base.node.mine.treeensemble2.sample.row.RowSample in project knime-core by knime.

Class TreeLearnerClassification, method learnSingleTreeRecursive.

private TreeModelClassification learnSingleTreeRecursive(final ExecutionMonitor exec, final RandomData rd) throws CanceledExecutionException {
    final TreeData data = getData();
    final RowSample rowSampling = getRowSampling();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    final TreeTargetNominalColumnData targetColumn = (TreeTargetNominalColumnData) data.getTargetColumn();
    // new RootDataMem(rowSampling, getIndexManager());
    final DataMemberships rootDataMemberships = new RootDataMemberships(rowSampling, data, getIndexManager());
    ClassificationPriors targetPriors = targetColumn.getDistribution(rootDataMemberships, config);
    BitSet forbiddenColumnSet = new BitSet(data.getNrAttributes());
    // final DataMemberships rootDataMemberships = new IntArrayDataMemberships(sampleWeights, data);
    final TreeNodeSignature rootSignature = TreeNodeSignature.ROOT_SIGNATURE;
    final ColumnSample rootColumnSample = getColSamplingStrategy().getColumnSampleForTreeNode(rootSignature);
    TreeNodeClassification rootNode = buildTreeNode(exec, 0, rootDataMemberships, rootColumnSample, rootSignature, targetPriors, forbiddenColumnSet);
    assert forbiddenColumnSet.cardinality() == 0;
    rootNode.setTreeNodeCondition(TreeNodeTrueCondition.INSTANCE);
    return new TreeModelClassification(rootNode);
}
Also used: TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration), RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships), TreeNodeClassification (org.knime.base.node.mine.treeensemble2.model.TreeNodeClassification), ColumnSample (org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample), BitSet (java.util.BitSet), TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData), RowSample (org.knime.base.node.mine.treeensemble2.sample.row.RowSample), TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature), TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData), ClassificationPriors (org.knime.base.node.mine.treeensemble2.data.ClassificationPriors), TreeModelClassification (org.knime.base.node.mine.treeensemble2.model.TreeModelClassification)
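
Here the root priors come from targetColumn.getDistribution(rootDataMemberships, config), i.e. the class distribution of the rows selected by the row sample. The sketch below shows one plausible way to compute such a distribution, assuming each row is weighted by how often the sample drew it, and then derives a Gini impurity from it. ClassPriorsSketch and its methods are hypothetical; Gini is used only as one common impurity measure, and the KNIME configuration may select a different split criterion.

```java
import java.util.Arrays;

/** Hypothetical sketch: root class distribution weighted by row-sample counts. */
public class ClassPriorsSketch {

    /** Adds each row's sample count to the bucket of its class. */
    static double[] weightedDistribution(int[] classIndexOfRow, int[] sampleCounts, int nrClasses) {
        double[] distribution = new double[nrClasses];
        for (int row = 0; row < classIndexOfRow.length; row++) {
            distribution[classIndexOfRow[row]] += sampleCounts[row];
        }
        return distribution;
    }

    /** Gini impurity 1 - sum(p_k^2) of a weighted class distribution. */
    static double gini(double[] distribution) {
        double total = Arrays.stream(distribution).sum();
        if (total == 0) {
            return 0; // empty node: treat as pure
        }
        double sumSquares = 0;
        for (double weight : distribution) {
            double p = weight / total;
            sumSquares += p * p;
        }
        return 1 - sumSquares;
    }

    public static void main(String[] args) {
        // Four rows, two classes; row 2 was drawn twice, row 3 not at all
        int[] classes = {0, 1, 0, 1};
        int[] counts = {1, 1, 2, 0};
        double[] dist = weightedDistribution(classes, counts, 2);
        System.out.println(Arrays.toString(dist)); // [3.0, 1.0]
        System.out.println(gini(dist));            // 0.375
    }
}
```

Weighting by counts means a row drawn twice influences the priors twice, while an out-of-bag row contributes nothing.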

Example 3 with RowSample

Use of org.knime.base.node.mine.treeensemble2.sample.row.RowSample in project knime-core by knime.

Class MGradientBoostedTreesLearner, method learn.

/**
 * {@inheritDoc}
 */
@Override
public AbstractGradientBoostingModel learn(final ExecutionMonitor exec) throws CanceledExecutionException {
    final TreeData actualData = getData();
    final GradientBoostingLearnerConfiguration config = getConfig();
    final int nrModels = config.getNrModels();
    final TreeTargetNumericColumnData actualTarget = getTarget();
    final double initialValue = actualTarget.getMedian();
    final ArrayList<TreeModelRegression> models = new ArrayList<TreeModelRegression>(nrModels);
    final ArrayList<Map<TreeNodeSignature, Double>> coefficientMaps = new ArrayList<Map<TreeNodeSignature, Double>>(nrModels);
    final double[] previousPrediction = new double[actualTarget.getNrRows()];
    Arrays.fill(previousPrediction, initialValue);
    final RandomData rd = config.createRandomData();
    final double alpha = config.getAlpha();
    TreeNodeSignatureFactory signatureFactory = null;
    final int maxLevels = config.getMaxLevels();
    // this should be the default
    if (maxLevels < TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE) {
        final int capacity = IntMath.pow(2, maxLevels - 1);
        signatureFactory = new TreeNodeSignatureFactory(capacity);
    } else {
        signatureFactory = new TreeNodeSignatureFactory();
    }
    exec.setMessage("Learning model");
    TreeData residualData;
    for (int i = 0; i < nrModels; i++) {
        final double[] residuals = new double[actualTarget.getNrRows()];
        for (int j = 0; j < actualTarget.getNrRows(); j++) {
            residuals[j] = actualTarget.getValueFor(j) - previousPrediction[j];
        }
        final double quantile = calculateAlphaQuantile(residuals, alpha);
        final double[] gradients = new double[residuals.length];
        for (int j = 0; j < gradients.length; j++) {
            gradients[j] = Math.abs(residuals[j]) <= quantile ? residuals[j] : quantile * Math.signum(residuals[j]);
        }
        residualData = createResidualDataFromArray(gradients, actualData);
        final RandomData rdSingle = TreeEnsembleLearnerConfiguration.createRandomData(rd.nextLong(Long.MIN_VALUE, Long.MAX_VALUE));
        final RowSample rowSample = getRowSampler().createRowSample(rdSingle);
        final TreeLearnerRegression treeLearner = new TreeLearnerRegression(getConfig(), residualData, getIndexManager(), signatureFactory, rdSingle, rowSample);
        final TreeModelRegression tree = treeLearner.learnSingleTree(exec, rdSingle);
        final Map<TreeNodeSignature, Double> coefficientMap = calcCoefficientMap(residuals, quantile, tree);
        adaptPreviousPrediction(previousPrediction, tree, coefficientMap);
        models.add(tree);
        coefficientMaps.add(coefficientMap);
        exec.setProgress((i + 1.0) / nrModels, "Finished model " + (i + 1) + "/" + nrModels);
    }
    return new GradientBoostedTreesModel(getConfig(), actualData.getMetaData(), models.toArray(new TreeModelRegression[models.size()]), actualData.getTreeType(), initialValue, coefficientMaps);
}
Also used: RandomData (org.apache.commons.math.random.RandomData), ArrayList (java.util.ArrayList), TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData), GradientBoostedTreesModel (org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel), TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature), TreeModelRegression (org.knime.base.node.mine.treeensemble2.model.TreeModelRegression), GradientBoostingLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration), TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData), TreeLearnerRegression (org.knime.base.node.mine.treeensemble2.learner.TreeLearnerRegression), RowSample (org.knime.base.node.mine.treeensemble2.sample.row.RowSample), HashMap (java.util.HashMap), Map (java.util.Map), TreeNodeSignatureFactory (org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory)
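
Each iteration of learn computes residuals against the current prediction, derives a threshold from calculateAlphaQuantile(residuals, alpha), and clips large residuals to that threshold before fitting the next tree on a fresh RowSample. This matches the Huber-loss variant of gradient boosting described by Friedman. The sketch below reproduces the clipping rule from the loop with plain arrays. The quantile helper (nearest-rank quantile of the absolute residuals) and the leaf coefficient (median residual plus the mean clipped deviation from it, as in Friedman's Huber M-TreeBoost) are assumptions about what calculateAlphaQuantile and calcCoefficientMap do; the KNIME methods may differ in detail.

```java
import java.util.Arrays;

/** Hypothetical sketch of one Huber-loss boosting step (not the KNIME code). */
public class HuberBoostingSketch {

    /** Alpha-quantile of |residuals| using the nearest-rank rule (an assumption). */
    static double alphaQuantileOfAbs(double[] residuals, double alpha) {
        double[] abs = Arrays.stream(residuals).map(Math::abs).sorted().toArray();
        int rank = (int) Math.ceil(alpha * abs.length); // 1-based rank
        return abs[Math.max(rank, 1) - 1];
    }

    /** Pseudo-residuals: small residuals kept as-is, large ones clipped to +/- quantile. */
    static double[] huberGradients(double[] residuals, double quantile) {
        double[] gradients = new double[residuals.length];
        for (int j = 0; j < residuals.length; j++) {
            gradients[j] = Math.abs(residuals[j]) <= quantile
                ? residuals[j]
                : quantile * Math.signum(residuals[j]);
        }
        return gradients;
    }

    /** Friedman-style leaf value: median residual plus mean clipped deviation from it. */
    static double leafCoefficient(double[] leafResiduals, double quantile) {
        double[] sorted = leafResiduals.clone();
        Arrays.sort(sorted);
        int n = sorted.length;
        double median = n % 2 == 1 ? sorted[n / 2] : (sorted[n / 2 - 1] + sorted[n / 2]) / 2;
        double correction = 0;
        for (double r : leafResiduals) {
            double deviation = r - median;
            correction += Math.signum(deviation) * Math.min(quantile, Math.abs(deviation));
        }
        return median + correction / n;
    }

    public static void main(String[] args) {
        double[] residuals = {-4.0, -0.5, 0.2, 1.0, 10.0};
        double delta = alphaQuantileOfAbs(residuals, 0.8);
        System.out.println("delta = " + delta);
        System.out.println(Arrays.toString(huberGradients(residuals, delta)));
        System.out.println("leaf = " + leafCoefficient(residuals, delta));
    }
}
```

Clipping keeps squared-error behaviour for residuals inside the threshold while bounding the pull of outliers such as the 10.0 above; alpha sets the fraction of residuals that are left unclipped.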

Example 4 with RowSample

Use of org.knime.base.node.mine.treeensemble2.sample.row.RowSample in project knime-core by knime.

Class RootDescendantDataMembershipsTest, method testCreateChildDataMemberships.

@Test
public void testCreateChildDataMemberships() {
    TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(false);
    TestDataGenerator dataGen = new TestDataGenerator(config);
    TreeData data = dataGen.createTennisData();
    DefaultDataIndexManager indexManager = new DefaultDataIndexManager(data);
    int nrRows = data.getNrRows();
    RowSample rowSample = new DefaultRowSample(nrRows);
    RootDataMemberships rootMemberships = new RootDataMemberships(rowSample, data, indexManager);
    BitSet firstHalf = new BitSet(nrRows);
    firstHalf.set(0, nrRows / 2);
    DataMemberships firstHalfChildMemberships = rootMemberships.createChildMemberships(firstHalf);
    assertThat(firstHalfChildMemberships, instanceOf(BitSetDescendantDataMemberships.class));
    BitSetDescendantDataMemberships bitSetFirstHalfChildMemberships = (BitSetDescendantDataMemberships) firstHalfChildMemberships;
    assertEquals(firstHalf, bitSetFirstHalfChildMemberships.getBitSet());
    BitSet firstQuarter = new BitSet(nrRows);
    firstQuarter.set(0, nrRows / 4);
    DataMemberships firstQuarterGrandChild = firstHalfChildMemberships.createChildMemberships(firstQuarter);
    assertThat(firstQuarterGrandChild, instanceOf(BitSetDescendantDataMemberships.class));
    BitSetDescendantDataMemberships bitSetFirstQuarterGrandChild = (BitSetDescendantDataMemberships) firstQuarterGrandChild;
    assertEquals(firstQuarter, bitSetFirstQuarterGrandChild.getBitSet());
}
Also used: TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration), DefaultRowSample (org.knime.base.node.mine.treeensemble2.sample.row.DefaultRowSample), BitSet (java.util.BitSet), TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData), RowSample (org.knime.base.node.mine.treeensemble2.sample.row.RowSample), TestDataGenerator (org.knime.base.node.mine.treeensemble2.data.TestDataGenerator), Test (org.junit.Test)
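
The test expects that a child created from a BitSet mask holds exactly that mask, and that a grandchild created from a narrower mask holds the narrower one. A minimal sketch of that behaviour, assuming a child is simply the intersection of its parent's rows with the mask over original row indices, is shown below. BitSetMembershipsSketch and its methods are hypothetical; the excerpt does not show how BitSetDescendantDataMemberships combines parent and mask internally.

```java
import java.util.BitSet;

/** Hypothetical sketch: a node's rows as a BitSet; a child restricts its parent. */
public class BitSetMembershipsSketch {

    private final BitSet rows;

    BitSetMembershipsSketch(BitSet rows) {
        this.rows = (BitSet) rows.clone(); // defensive copy
    }

    /** Root node: every row (assumes a sample that contains all rows). */
    static BitSetMembershipsSketch root(int nrRows) {
        BitSet all = new BitSet(nrRows);
        all.set(0, nrRows);
        return new BitSetMembershipsSketch(all);
    }

    /** Child keeps only rows present in both the parent and the mask. */
    BitSetMembershipsSketch createChild(BitSet mask) {
        BitSet child = (BitSet) rows.clone();
        child.and(mask);
        return new BitSetMembershipsSketch(child);
    }

    BitSet getBitSet() {
        return (BitSet) rows.clone();
    }

    public static void main(String[] args) {
        int nrRows = 14; // same size as the tennis data in the test
        BitSet firstHalf = new BitSet(nrRows);
        firstHalf.set(0, nrRows / 2);
        BitSet firstQuarter = new BitSet(nrRows);
        firstQuarter.set(0, nrRows / 4);
        BitSetMembershipsSketch child = root(nrRows).createChild(firstHalf);
        BitSetMembershipsSketch grandChild = child.createChild(firstQuarter);
        System.out.println(child.getBitSet());      // {0, 1, 2, 3, 4, 5, 6}
        System.out.println(grandChild.getBitSet()); // {0, 1, 2}
    }
}
```

Intersecting with the parent guarantees that a descendant never contains rows its ancestors excluded, which is why the first-quarter mask yields the same set at the grandchild level.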

Example 5 with RowSample

Use of org.knime.base.node.mine.treeensemble2.sample.row.RowSample in project knime-core by knime.

Class RootDescendantDataMembershipsTest, method testGetColumnMemberships.

@Test
public void testGetColumnMemberships() {
    TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(false);
    TestDataGenerator dataGen = new TestDataGenerator(config);
    TreeData data = dataGen.createTennisData();
    DefaultDataIndexManager indexManager = new DefaultDataIndexManager(data);
    int nrRows = data.getNrRows();
    RowSample rowSample = new DefaultRowSample(nrRows);
    RootDataMemberships rootMemberships = new RootDataMemberships(rowSample, data, indexManager);
    ColumnMemberships rootColMem = rootMemberships.getColumnMemberships(0);
    assertThat(rootColMem, instanceOf(IntArrayColumnMemberships.class));
    assertEquals(nrRows, rootColMem.size());
    int[] expectedOriginalIndices = new int[] { 0, 1, 7, 8, 10, 2, 6, 11, 12, 3, 4, 5, 9, 13 };
    for (int i = 0; rootColMem.next(); i++) {
        // in this case originalIndex and indexInDataMemberships are the same
        assertEquals(expectedOriginalIndices[i], rootColMem.getOriginalIndex());
        assertEquals(expectedOriginalIndices[i], rootColMem.getIndexInDataMemberships());
        assertEquals(i, rootColMem.getIndexInColumn());
    }
    BitSet lastHalf = new BitSet(nrRows);
    lastHalf.set(nrRows / 2, nrRows);
    DataMemberships lastHalfChild = rootMemberships.createChildMemberships(lastHalf);
    ColumnMemberships childColMem = lastHalfChild.getColumnMemberships(0);
    assertThat(childColMem, instanceOf(DescendantColumnMemberships.class));
    assertEquals(nrRows / 2, childColMem.size());
    expectedOriginalIndices = new int[] { 7, 8, 10, 11, 12, 9, 13 };
    int[] expectedIndexInColumn = new int[] { 2, 3, 4, 7, 8, 12, 13 };
    int[] expectedIndexInDataMemberships = new int[] { 7, 8, 10, 11, 12, 9, 13 };
    for (int i = 0; childColMem.next(); i++) {
        assertEquals(expectedOriginalIndices[i], childColMem.getOriginalIndex());
        assertEquals(expectedIndexInColumn[i], childColMem.getIndexInColumn());
        assertEquals(expectedIndexInDataMemberships[i], childColMem.getIndexInDataMemberships());
    }
}
Also used: TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration), DefaultRowSample (org.knime.base.node.mine.treeensemble2.sample.row.DefaultRowSample), BitSet (java.util.BitSet), TestDataGenerator (org.knime.base.node.mine.treeensemble2.data.TestDataGenerator), TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData), RowSample (org.knime.base.node.mine.treeensemble2.sample.row.RowSample), Test (org.junit.Test)
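
The expected arrays in testGetColumnMemberships follow from walking column 0 in its stored row order and keeping only the rows that belong to the child node: indexInColumn is the position in that order, and originalIndex is the row found there. The sketch below reproduces the test's expected values with plain arrays, taking the column order from the test's root expectedOriginalIndices. ColumnMembershipsSketch and its methods are hypothetical and only model this filtering, not the KNIME ColumnMemberships iterator.

```java
import java.util.Arrays;
import java.util.BitSet;
import java.util.stream.IntStream;

/** Hypothetical sketch: filter a column's row order down to one node's rows. */
public class ColumnMembershipsSketch {

    /** Positions in the column order whose row belongs to the node. */
    static int[] indicesInColumn(int[] columnOrder, BitSet nodeRows) {
        return IntStream.range(0, columnOrder.length)
            .filter(pos -> nodeRows.get(columnOrder[pos]))
            .toArray();
    }

    /** Original row indices at those positions, still in column order. */
    static int[] originalIndices(int[] columnOrder, BitSet nodeRows) {
        return Arrays.stream(indicesInColumn(columnOrder, nodeRows))
            .map(pos -> columnOrder[pos])
            .toArray();
    }

    public static void main(String[] args) {
        // Row order of column 0, taken from the test's root expectations
        int[] columnOrder = {0, 1, 7, 8, 10, 2, 6, 11, 12, 3, 4, 5, 9, 13};
        BitSet lastHalf = new BitSet(14);
        lastHalf.set(7, 14);
        System.out.println(Arrays.toString(originalIndices(columnOrder, lastHalf)));
        // [7, 8, 10, 11, 12, 9, 13]
        System.out.println(Arrays.toString(indicesInColumn(columnOrder, lastHalf)));
        // [2, 3, 4, 7, 8, 12, 13]
    }
}
```

Because the child keeps the column's order instead of re-sorting, a split search can scan the child's rows in attribute order without sorting again, which is presumably why the test checks indexInColumn explicitly.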

Aggregations

TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData): 6
RowSample (org.knime.base.node.mine.treeensemble2.sample.row.RowSample): 6
BitSet (java.util.BitSet): 5
TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration): 5
RandomData (org.apache.commons.math.random.RandomData): 3
TreeModelRegression (org.knime.base.node.mine.treeensemble2.model.TreeModelRegression): 3
TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature): 3
Test (org.junit.Test): 2
TestDataGenerator (org.knime.base.node.mine.treeensemble2.data.TestDataGenerator): 2
TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData): 2
IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager): 2
RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships): 2
TreeLearnerRegression (org.knime.base.node.mine.treeensemble2.learner.TreeLearnerRegression): 2
TreeNodeSignatureFactory (org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory): 2
GradientBoostingLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration): 2
ColumnSample (org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample): 2
DefaultRowSample (org.knime.base.node.mine.treeensemble2.sample.row.DefaultRowSample): 2
ExecutionMonitor (org.knime.core.node.ExecutionMonitor): 2
ArrayList (java.util.ArrayList): 1
HashMap (java.util.HashMap): 1