Search in sources:

Example 26 with TreeData

Use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.

Class RootDescendantDataMembershipsTest, method testGetColumnMemberships.

/**
 * Checks the column memberships of column 0 of the tennis data, first for the root node (all rows) and then for a
 * child node that contains only the last half of the rows.
 *
 * <p>For each node it verifies the concrete memberships type, the reported size, and, for every element visited by
 * {@code next()}, the original row index, the index in the data memberships and the index in the column. It also
 * verifies that the iteration visits exactly the expected number of elements.
 */
@Test
public void testGetColumnMemberships() {
    TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(false);
    TestDataGenerator dataGen = new TestDataGenerator(config);
    TreeData data = dataGen.createTennisData();
    DefaultDataIndexManager indexManager = new DefaultDataIndexManager(data);
    int nrRows = data.getNrRows();
    RowSample rowSample = new DefaultRowSample(nrRows);
    RootDataMemberships rootMemberships = new RootDataMemberships(rowSample, data, indexManager);

    // Root node: every row is a member.
    ColumnMemberships rootColMem = rootMemberships.getColumnMemberships(0);
    assertThat(rootColMem, instanceOf(IntArrayColumnMemberships.class));
    assertEquals(nrRows, rootColMem.size());
    int[] expectedOriginalIndices = new int[] { 0, 1, 7, 8, 10, 2, 6, 11, 12, 3, 4, 5, 9, 13 };
    int rootVisited = 0;
    while (rootColMem.next()) {
        // in this case originalIndex and indexInDataMemberships are the same, so both are checked
        // against the same expectation
        assertEquals(expectedOriginalIndices[rootVisited], rootColMem.getOriginalIndex());
        assertEquals(expectedOriginalIndices[rootVisited], rootColMem.getIndexInDataMemberships());
        assertEquals(rootVisited, rootColMem.getIndexInColumn());
        rootVisited++;
    }
    // guard against an iteration that stops early (an overrun would already fail on the array access)
    assertEquals(expectedOriginalIndices.length, rootVisited);

    // Child node: only the rows in [nrRows / 2, nrRows) are members.
    BitSet lastHalf = new BitSet(nrRows);
    lastHalf.set(nrRows / 2, nrRows);
    DataMemberships lastHalfChild = rootMemberships.createChildMemberships(lastHalf);
    ColumnMemberships childColMem = lastHalfChild.getColumnMemberships(0);
    assertThat(childColMem, instanceOf(DescendantColumnMemberships.class));
    assertEquals(nrRows / 2, childColMem.size());
    expectedOriginalIndices = new int[] { 7, 8, 10, 11, 12, 9, 13 };
    int[] expectedIndexInColumn = new int[] { 2, 3, 4, 7, 8, 12, 13 };
    int[] expectedIndexInDataMemberships = new int[] { 7, 8, 10, 11, 12, 9, 13 };
    int childVisited = 0;
    while (childColMem.next()) {
        assertEquals(expectedOriginalIndices[childVisited], childColMem.getOriginalIndex());
        assertEquals(expectedIndexInColumn[childVisited], childColMem.getIndexInColumn());
        assertEquals(expectedIndexInDataMemberships[childVisited], childColMem.getIndexInDataMemberships());
        childVisited++;
    }
    assertEquals(expectedOriginalIndices.length, childVisited);
}
Also used: TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration), DefaultRowSample(org.knime.base.node.mine.treeensemble2.sample.row.DefaultRowSample), BitSet(java.util.BitSet), TestDataGenerator(org.knime.base.node.mine.treeensemble2.data.TestDataGenerator), TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData), RowSample(org.knime.base.node.mine.treeensemble2.sample.row.RowSample), Test(org.junit.Test)

Example 27 with TreeData

Use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.

Class GradientBoostingClassificationLearnerNodeModel, method execute.

/**
 * {@inheritDoc}
 *
 * <p>Filters the input table down to the configured learning columns, reads it into memory as {@link TreeData},
 * learns a gradient boosted trees model with {@link LKGradientBoostedTreesLearner} and returns it wrapped in a
 * {@link GradientBoostingModelPortObject}. Warnings from column filtering and data reading are joined with a newline
 * and set as the node warning.
 */
@Override
protected PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws Exception {
    BufferedDataTable t = (BufferedDataTable) inData[0];
    DataTableSpec spec = t.getDataTableSpec();
    final FilterLearnColumnRearranger learnRearranger = m_configuration.filterLearnColumns(spec);
    String warn = learnRearranger.getWarning();
    BufferedDataTable learnTable = exec.createColumnRearrangeTable(t, learnRearranger, exec.createSubProgress(0.0));
    DataTableSpec learnSpec = learnTable.getDataTableSpec();
    TreeEnsembleModelPortObjectSpec ensembleSpec = m_configuration.createPortObjectSpec(learnSpec);
    ExecutionMonitor readInExec = exec.createSubProgress(0.1);
    ExecutionMonitor learnExec = exec.createSubProgress(0.8);
    TreeDataCreator dataCreator = new TreeDataCreator(m_configuration, learnSpec, learnTable.getRowCount());
    // status text only; numeric progress is reported through the sub-monitors
    exec.setMessage("Reading data into memory");
    TreeData data = dataCreator.readData(learnTable, m_configuration, readInExec);
    String dataCreationWarning = dataCreator.getAndClearWarningMessage();
    if (dataCreationWarning != null) {
        // keep an earlier column-filter warning and append the data-reading warning on a new line
        warn = (warn == null) ? dataCreationWarning : warn + "\n" + dataCreationWarning;
    }
    readInExec.setProgress(1.0);
    exec.setMessage("Learning trees");
    AbstractGradientBoostingLearner learner = new LKGradientBoostedTreesLearner(m_configuration, data);
    // any exception thrown by the learner propagates unchanged to the caller
    AbstractGradientBoostingModel model = learner.learn(learnExec);
    GradientBoostingModelPortObject modelPortObject = new GradientBoostingModelPortObject(ensembleSpec, model);
    learnExec.setProgress(1.0);
    if (warn != null) {
        setWarningMessage(warn);
    }
    return new PortObject[] { modelPortObject };
}
Also used: DataTableSpec(org.knime.core.data.DataTableSpec), GradientBoostingModelPortObject(org.knime.base.node.mine.treeensemble2.model.GradientBoostingModelPortObject), TreeEnsembleModelPortObjectSpec(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec), LKGradientBoostedTreesLearner(org.knime.base.node.mine.treeensemble2.learner.gradientboosting.LKGradientBoostedTreesLearner), AbstractGradientBoostingModel(org.knime.base.node.mine.treeensemble2.model.AbstractGradientBoostingModel), BufferedDataTable(org.knime.core.node.BufferedDataTable), FilterLearnColumnRearranger(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration.FilterLearnColumnRearranger), TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData), ExecutionMonitor(org.knime.core.node.ExecutionMonitor), TreeDataCreator(org.knime.base.node.mine.treeensemble2.data.TreeDataCreator), PortObject(org.knime.core.node.port.PortObject), AbstractGradientBoostingLearner(org.knime.base.node.mine.treeensemble2.learner.gradientboosting.AbstractGradientBoostingLearner)

Example 28 with TreeData

Use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.

Class GradientBoostingRegressionLearnerNodeModel, method execute.

/**
 * {@inheritDoc}
 *
 * <p>Filters the input table down to the configured learning columns, reads it into memory as {@link TreeData},
 * learns a gradient boosted trees model with {@link MGradientBoostedTreesLearner} and returns it wrapped in a
 * {@link GradientBoostingModelPortObject}. Warnings from column filtering and data reading are joined with a newline
 * and set as the node warning.
 */
@Override
protected PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws Exception {
    BufferedDataTable t = (BufferedDataTable) inData[0];
    DataTableSpec spec = t.getDataTableSpec();
    final FilterLearnColumnRearranger learnRearranger = m_configuration.filterLearnColumns(spec);
    String warn = learnRearranger.getWarning();
    BufferedDataTable learnTable = exec.createColumnRearrangeTable(t, learnRearranger, exec.createSubProgress(0.0));
    DataTableSpec learnSpec = learnTable.getDataTableSpec();
    TreeEnsembleModelPortObjectSpec ensembleSpec = m_configuration.createPortObjectSpec(learnSpec);
    ExecutionMonitor readInExec = exec.createSubProgress(0.1);
    ExecutionMonitor learnExec = exec.createSubProgress(0.8);
    TreeDataCreator dataCreator = new TreeDataCreator(m_configuration, learnSpec, learnTable.getRowCount());
    // status text only; numeric progress is reported through the sub-monitors
    exec.setMessage("Reading data into memory");
    TreeData data = dataCreator.readData(learnTable, m_configuration, readInExec);
    String dataCreationWarning = dataCreator.getAndClearWarningMessage();
    if (dataCreationWarning != null) {
        // keep an earlier column-filter warning and append the data-reading warning on a new line
        warn = (warn == null) ? dataCreationWarning : warn + "\n" + dataCreationWarning;
    }
    readInExec.setProgress(1.0);
    exec.setMessage("Learning trees");
    AbstractGradientBoostingLearner learner = new MGradientBoostedTreesLearner(m_configuration, data);
    // any exception thrown by the learner propagates unchanged to the caller
    AbstractGradientBoostingModel model = learner.learn(learnExec);
    GradientBoostingModelPortObject modelPortObject = new GradientBoostingModelPortObject(ensembleSpec, model);
    learnExec.setProgress(1.0);
    if (warn != null) {
        setWarningMessage(warn);
    }
    return new PortObject[] { modelPortObject };
}
Also used: DataTableSpec(org.knime.core.data.DataTableSpec), GradientBoostingModelPortObject(org.knime.base.node.mine.treeensemble2.model.GradientBoostingModelPortObject), TreeEnsembleModelPortObjectSpec(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec), MGradientBoostedTreesLearner(org.knime.base.node.mine.treeensemble2.learner.gradientboosting.MGradientBoostedTreesLearner), AbstractGradientBoostingModel(org.knime.base.node.mine.treeensemble2.model.AbstractGradientBoostingModel), BufferedDataTable(org.knime.core.node.BufferedDataTable), FilterLearnColumnRearranger(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration.FilterLearnColumnRearranger), TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData), ExecutionMonitor(org.knime.core.node.ExecutionMonitor), TreeDataCreator(org.knime.base.node.mine.treeensemble2.data.TreeDataCreator), PortObject(org.knime.core.node.port.PortObject), AbstractGradientBoostingLearner(org.knime.base.node.mine.treeensemble2.learner.gradientboosting.AbstractGradientBoostingLearner)

Example 29 with TreeData

Use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.

Class TreeEnsembleRegressionLearnerNodeModel, method execute.

/**
 * {@inheritDoc}
 *
 * <p>Filters the input table down to the configured learning columns and reads it into memory as {@link TreeData}.
 * It then learns a tree ensemble, stores the model in a new file store, and computes out-of-bag predictions on the
 * original input table. Returns the out-of-bag prediction table, the column statistics table and the model port
 * object, in that order. Warnings from column filtering and data reading are joined with a newline and set as the
 * node warning.
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    BufferedDataTable t = (BufferedDataTable) inObjects[0];
    DataTableSpec spec = t.getDataTableSpec();
    final FilterLearnColumnRearranger learnRearranger = m_configuration.filterLearnColumns(spec);
    String warn = learnRearranger.getWarning();
    BufferedDataTable learnTable = exec.createColumnRearrangeTable(t, learnRearranger, exec.createSubProgress(0.0));
    DataTableSpec learnSpec = learnTable.getDataTableSpec();
    TreeEnsembleModelPortObjectSpec ensembleSpec = m_configuration.createPortObjectSpec(learnSpec);
    ExecutionMonitor readInExec = exec.createSubProgress(0.1);
    ExecutionMonitor learnExec = exec.createSubProgress(0.8);
    ExecutionMonitor outOfBagExec = exec.createSubProgress(0.1);
    TreeDataCreator dataCreator = new TreeDataCreator(m_configuration, learnSpec, learnTable.getRowCount());
    // status text only; numeric progress is reported through the sub-monitors
    exec.setMessage("Reading data into memory");
    TreeData data = dataCreator.readData(learnTable, m_configuration, readInExec);
    m_hiliteRowSample = dataCreator.getDataRowsForHilite();
    m_viewMessage = dataCreator.getViewMessage();
    String dataCreationWarning = dataCreator.getAndClearWarningMessage();
    if (dataCreationWarning != null) {
        // keep an earlier column-filter warning and append the data-reading warning on a new line
        warn = (warn == null) ? dataCreationWarning : warn + "\n" + dataCreationWarning;
    }
    readInExec.setProgress(1.0);
    exec.setMessage("Learning trees");
    TreeEnsembleLearner learner = new TreeEnsembleLearner(m_configuration, data);
    TreeEnsembleModel model;
    try {
        model = learner.learnEnsemble(learnExec);
    } catch (ExecutionException e) {
        // Surface the underlying Exception rather than the ExecutionException wrapper.
        // Non-Exception causes (i.e. Errors) keep the wrapper.
        Throwable cause = e.getCause();
        if (cause instanceof Exception) {
            throw (Exception) cause;
        }
        throw e;
    }
    TreeEnsembleModelPortObject modelPortObject = TreeEnsembleModelPortObject.createPortObject(ensembleSpec, model,
        exec.createFileStore(UUID.randomUUID().toString()));
    learnExec.setProgress(1.0);
    exec.setMessage("Out of bag prediction");
    TreeEnsemblePredictor outOfBagPredictor = createOutOfBagPredictor(ensembleSpec, modelPortObject, spec);
    // restrict prediction using the learner's row samples; presumably only rows not sampled for a tree
    // are predicted by it (TODO confirm against TreeEnsemblePredictor)
    outOfBagPredictor.setOutofBagFilter(learner.getRowSamples(), data.getTargetColumn());
    ColumnRearranger outOfBagRearranger = outOfBagPredictor.getPredictionRearranger();
    // predictions are appended to the unfiltered input table t
    BufferedDataTable outOfBagTable = exec.createColumnRearrangeTable(t, outOfBagRearranger, outOfBagExec);
    BufferedDataTable colStatsTable = learner.createColumnStatisticTable(exec.createSubExecutionContext(0.0));
    m_ensembleModelPortObject = modelPortObject;
    if (warn != null) {
        setWarningMessage(warn);
    }
    return new PortObject[] { outOfBagTable, colStatsTable, modelPortObject };
}
Also used: DataTableSpec(org.knime.core.data.DataTableSpec), TreeEnsembleModel(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModel), TreeEnsembleModelPortObjectSpec(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec), TreeEnsembleLearner(org.knime.base.node.mine.treeensemble2.learner.TreeEnsembleLearner), InvalidSettingsException(org.knime.core.node.InvalidSettingsException), CanceledExecutionException(org.knime.core.node.CanceledExecutionException), IOException(java.io.IOException), ExecutionException(java.util.concurrent.ExecutionException), TreeEnsembleModelPortObject(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject), ColumnRearranger(org.knime.core.data.container.ColumnRearranger), FilterLearnColumnRearranger(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration.FilterLearnColumnRearranger), BufferedDataTable(org.knime.core.node.BufferedDataTable), TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData), ExecutionMonitor(org.knime.core.node.ExecutionMonitor), TreeEnsemblePredictor(org.knime.base.node.mine.treeensemble2.node.predictor.TreeEnsemblePredictor), TreeDataCreator(org.knime.base.node.mine.treeensemble2.data.TreeDataCreator), PortObject(org.knime.core.node.port.PortObject)

Example 30 with TreeData

Use of org.knime.base.node.mine.treeensemble2.data.TreeData in project knime-core by knime.

Class TreeNominalColumnDataTest, method testCalcBestSplitClassificationBinaryTwoClass.

/**
 * Exercises
 * {@link TreeNominalColumnData#calcBestSplitClassification(DataMemberships, ClassificationPriors, TreeTargetNominalColumnData, RandomData)}
 * on a two-class version of the tennis data, with surrogate missing-value handling and the Gini criterion.
 *
 * <p>Every row gets weight 1. The best split must be a binary nominal split with a gain of about 0.1371428. Both
 * child conditions must refer to the single value "R": the first with {@code IS_NOT_IN}, the second with
 * {@code IS_IN}. Neither child may accept missing values.
 *
 * @throws Exception if building the test data or computing the split fails
 */
@Test
public void testCalcBestSplitClassificationBinaryTwoClass() throws Exception {
    final TreeEnsembleLearnerConfiguration cfg = createConfig(false);
    cfg.setMissingValueHandling(MissingValueHandling.Surrogate);
    final Pair<TreeNominalColumnData, TreeTargetNominalColumnData> tennis = twoClassTennisData(cfg);
    final TreeNominalColumnData attribute = tennis.getFirst();
    final TreeTargetNominalColumnData target = tennis.getSecond();
    final TreeData tennisTree = twoClassTennisTreeData(cfg);
    final IDataIndexManager indexMgr = new DefaultDataIndexManager(tennisTree);
    assertEquals(SplitCriterion.Gini, cfg.getSplitCriterion());

    // unit weight for every row of the two-class subset
    final double[] weights = new double[TWO_CLASS_INDICES.length];
    Arrays.fill(weights, 1.0);
    final DataMemberships memberships = new RootDataMemberships(weights, tennisTree, indexMgr);
    final ClassificationPriors classPriors = target.getDistribution(weights, cfg);

    // no random data source is passed to the split search
    final SplitCandidate candidate = attribute.calcBestSplitClassification(memberships, classPriors, target, null);
    assertNotNull(candidate);
    assertThat(candidate, instanceOf(NominalBinarySplitCandidate.class));
    assertTrue(candidate.canColumnBeSplitFurther());
    // reference gain computed by hand in a spreadsheet
    assertEquals(0.1371428, candidate.getGainValue(), 0.00001);

    final TreeNodeNominalBinaryCondition[] conditions =
        ((NominalBinarySplitCandidate) candidate).getChildConditions();
    assertEquals(2, conditions.length);
    assertArrayEquals(new String[] { "R" }, conditions[0].getValues());
    assertArrayEquals(new String[] { "R" }, conditions[1].getValues());
    assertEquals(SetLogic.IS_NOT_IN, conditions[0].getSetLogic());
    assertEquals(SetLogic.IS_IN, conditions[1].getSetLogic());
    assertFalse(conditions[0].acceptsMissings());
    assertFalse(conditions[1].acceptsMissings());
}
Also used: TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration), RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships), IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager), NominalMultiwaySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate), NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate), SplitCandidate(org.knime.base.node.mine.treeensemble2.learner.SplitCandidate), DefaultDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager), DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships), TreeNodeNominalBinaryCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition), Test(org.junit.Test)

Aggregations

TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData): 27; TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration): 25; Test (org.junit.Test): 18; DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships): 18; RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships): 18; DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager): 15; IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager): 15; RandomData (org.apache.commons.math.random.RandomData): 14; TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData): 12; SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate): 12; BitSet (java.util.BitSet): 11; TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData): 8; TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature): 8; ExecutionMonitor (org.knime.core.node.ExecutionMonitor): 8; TreeDataCreator (org.knime.base.node.mine.treeensemble2.data.TreeDataCreator): 7; TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData): 7; NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate): 7; NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate): 7; FilterLearnColumnRearranger (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration.FilterLearnColumnRearranger): 7; TreeEnsembleModelPortObjectSpec (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec): 6