Search in sources:

Example 1 with TreeNodeSignatureFactory

use of org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory in project knime-core by knime.

From method testGetColumnSampleForTreeNode of class SubsetColumnSampleStrategyTest:

/**
 * Tests {@link SubsetColumnSampleStrategy#getColumnSampleForTreeNode(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)}
 * together with {@link SubsetColumnSample}, because the two always act in combination.
 * <p>
 * The sample drawn for the root node must contain 5 columns. Both children of the root (branch 0 and branch 1) must
 * also contain 5 columns and must receive exactly the same column indices as the root.
 *
 * @throws Exception propagated from the test fixtures
 */
@Test
public void testGetColumnSampleForTreeNode() throws Exception {
    final SubsetColumnSampleStrategy strategy = new SubsetColumnSampleStrategy(createTreeData(), RD, 5);
    final TreeNodeSignatureFactory signatureFactory = createSignatureFactory();
    final TreeNodeSignature rootSignature = signatureFactory.getRootSignature();

    final ColumnSample rootSample = strategy.getColumnSampleForTreeNode(rootSignature);
    assertEquals("Wrong number of columns in sample.", 5, rootSample.getNumCols());
    final int[] rootIndices = rootSample.getColumnIndices();

    // branch 0 is checked before branch 1, matching the order in which the child signatures are requested
    for (byte branch = 0; branch <= 1; branch++) {
        final TreeNodeSignature childSignature = signatureFactory.getChildSignatureFor(rootSignature, branch);
        final ColumnSample childSample = strategy.getColumnSampleForTreeNode(childSignature);
        assertEquals("Wrong number of columns in sample.", 5, childSample.getNumCols());
        assertArrayEquals(rootIndices, childSample.getColumnIndices());
    }
}
Also used : TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeSignatureFactory(org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory) Test(org.junit.Test)

Example 2 with TreeNodeSignatureFactory

use of org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory in project knime-core by knime.

From method learn of class LKGradientBoostedTreesLearner:

/**
 * {@inheritDoc}
 * <p>
 * Learns a multi-class gradient boosted trees model. The k-class classification problem is transformed into k
 * regression problems, one per class. In every boosting iteration, the class probabilities are first derived from
 * the previous function values via a softmax. Then one regression tree per class is learned as a task on the KNIME
 * global thread pool.
 *
 * @param exec the monitor used for status messages and progress reporting
 * @return the learned multi-class model, holding {@code nrModels x numClasses} regression trees
 * @throws CanceledExecutionException as declared by the learner contract
 * @throws InterruptedException if the thread is interrupted while waiting on the semaphore or for a task result
 * @throws ExecutionException if retrieving the result of a tree learning task fails
 */
@Override
public MultiClassGradientBoostedTreesModel learn(final ExecutionMonitor exec) throws CanceledExecutionException, InterruptedException, ExecutionException {
    final TreeData data = getData();
    final TreeTargetNominalColumnData target = (TreeTargetNominalColumnData) data.getTargetColumn();
    final NominalValueRepresentation[] classNomVals = target.getMetaData().getValues();
    final int numClasses = classNomVals.length;
    final String[] classLabels = new String[numClasses];
    final int nrModels = getConfig().getNrModels();
    final int nrRows = target.getNrRows();
    final TreeModelRegression[][] models = new TreeModelRegression[nrModels][numClasses];
    final ArrayList<ArrayList<Map<TreeNodeSignature, Double>>> coefficientMaps = new ArrayList<ArrayList<Map<TreeNodeSignature, Double>>>(nrModels);
    // variables for parallelization
    final ThreadPool tp = KNIMEConstants.GLOBAL_THREAD_POOL;
    final AtomicReference<Throwable> learnThrowableRef = new AtomicReference<Throwable>();
    // number of semaphore permits: 1.5 times the number of available processors
    final int procCount = 3 * Runtime.getRuntime().availableProcessors() / 2;
    exec.setMessage("Transforming problem");
    // transform the original k class classification problem into k regression problems
    final TreeData[] actual = new TreeData[numClasses];
    for (int i = 0; i < numClasses; i++) {
        final double[] newTarget = calculateNewTarget(target, i);
        actual[i] = createNumericDataFromArray(newTarget);
        classLabels[i] = classNomVals[i].getNominalValue();
    }
    final RandomData rd = getConfig().createRandomData();
    // function values of the boosted model so far, one row of values per class; all start at 0
    final double[][] previousFunctions = new double[numClasses][nrRows];
    final TreeNodeSignatureFactory signatureFactory;
    final int maxLevels = getConfig().getMaxLevels();
    if (maxLevels < TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE) {
        // NOTE(review): IntMath.pow silently overflows for maxLevels > 32 (yields Integer.MIN_VALUE or 0);
        // confirm that the configuration bounds maxLevels accordingly
        final int capacity = IntMath.pow(2, maxLevels - 1);
        signatureFactory = new TreeNodeSignatureFactory(capacity);
    } else {
        signatureFactory = new TreeNodeSignatureFactory();
    }
    exec.setMessage("Learn trees");
    for (int i = 0; i < nrModels; i++) {
        final Semaphore semaphore = new Semaphore(procCount);
        final ArrayList<Map<TreeNodeSignature, Double>> classCoefficientMaps = new ArrayList<Map<TreeNodeSignature, Double>>(numClasses);
        // prepare calculation of pseudoResiduals
        final double[][] probs = calculateClassProbabilities(previousFunctions, nrRows);
        final Future<?>[] treeCoefficientMapPairs = new Future<?>[numClasses];
        for (int j = 0; j < numClasses; j++) {
            checkThrowable(learnThrowableRef);
            final RandomData rdSingle = TreeEnsembleLearnerConfiguration.createRandomData(rd.nextLong(Long.MIN_VALUE, Long.MAX_VALUE));
            final ExecutionMonitor subExec = exec.createSubProgress(0.0);
            semaphore.acquire();
            treeCoefficientMapPairs[j] = tp.enqueue(new TreeLearnerCallable(rdSingle, probs[j], actual[j], subExec, numClasses, previousFunctions[j], semaphore, learnThrowableRef, signatureFactory));
        }
        for (int j = 0; j < numClasses; j++) {
            checkThrowable(learnThrowableRef);
            semaphore.acquire();
            // the tasks enqueued above are TreeLearnerCallables, whose result is a (tree, coefficient map) pair
            @SuppressWarnings("unchecked")
            final Pair<TreeModelRegression, Map<TreeNodeSignature, Double>> pair = (Pair<TreeModelRegression, Map<TreeNodeSignature, Double>>) treeCoefficientMapPairs[j].get();
            models[i][j] = pair.getFirst();
            classCoefficientMaps.add(pair.getSecond());
            semaphore.release();
        }
        checkThrowable(learnThrowableRef);
        coefficientMaps.add(classCoefficientMaps);
        // i is zero-based, so i + 1 models are complete at this point
        exec.setProgress((double) (i + 1) / nrModels, "Finished level " + (i + 1) + "/" + nrModels);
    }
    return MultiClassGradientBoostedTreesModel.createMultiClassGradientBoostedTreesModel(getConfig(), data.getMetaData(), models, data.getTreeType(), 0, numClasses, coefficientMaps, classLabels);
}

/**
 * Computes the per-row softmax of the previous function values.
 * <p>
 * Before exponentiating, the maximum function value of each row is subtracted (log-sum-exp trick). This is
 * mathematically equivalent to the plain softmax. It prevents {@code Math.exp} from overflowing to infinity, which
 * would otherwise produce NaN probabilities once the function values grow large.
 *
 * @param previousFunctions function values indexed as {@code [class][row]}
 * @param nrRows the number of rows
 * @return the class probabilities indexed as {@code [class][row]}; each row sums to 1
 */
private static double[][] calculateClassProbabilities(final double[][] previousFunctions, final int nrRows) {
    final int numClasses = previousFunctions.length;
    final double[][] probs = new double[numClasses][nrRows];
    for (int r = 0; r < nrRows; r++) {
        double maxF = Double.NEGATIVE_INFINITY;
        for (int j = 0; j < numClasses; j++) {
            maxF = Math.max(maxF, previousFunctions[j][r]);
        }
        double sumExpF = 0;
        for (int j = 0; j < numClasses; j++) {
            final double expF = Math.exp(previousFunctions[j][r] - maxF);
            probs[j][r] = expF;
            sumExpF += expF;
        }
        for (int j = 0; j < numClasses; j++) {
            probs[j][r] /= sumExpF;
        }
    }
    return probs;
}
Also used : RandomData(org.apache.commons.math.random.RandomData) ArrayList(java.util.ArrayList) ThreadPool(org.knime.core.util.ThreadPool) NominalValueRepresentation(org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation) Semaphore(java.util.concurrent.Semaphore) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) ExecutionMonitor(org.knime.core.node.ExecutionMonitor) TreeNodeSignatureFactory(org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory) Pair(org.knime.core.util.Pair) AtomicReference(java.util.concurrent.atomic.AtomicReference) Future(java.util.concurrent.Future) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) HashMap(java.util.HashMap) Map(java.util.Map) TreeTargetNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)

Example 3 with TreeNodeSignatureFactory

use of org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory in project knime-core by knime.

From method learn of class MGradientBoostedTreesLearner:

/**
 * {@inheritDoc}
 * <p>
 * Learns a gradient boosted regression model. Every row starts with the median of the target as its prediction.
 * Each iteration then does the following:
 * <ol>
 * <li>computes the residuals of the current prediction;</li>
 * <li>clips every residual whose magnitude exceeds the value returned by {@code calculateAlphaQuantile} to
 * {@code quantile * signum(residual)};</li>
 * <li>fits a regression tree to the clipped gradients and updates the prediction.</li>
 * </ol>
 *
 * @param exec the monitor used for status messages, progress reporting and single tree learning
 * @return the learned {@link GradientBoostedTreesModel}
 * @throws CanceledExecutionException as declared by the learner contract
 */
@Override
public AbstractGradientBoostingModel learn(final ExecutionMonitor exec) throws CanceledExecutionException {
    final TreeData actualData = getData();
    final GradientBoostingLearnerConfiguration config = getConfig();
    final int nrModels = config.getNrModels();
    final TreeTargetNumericColumnData actualTarget = getTarget();
    final int nrRows = actualTarget.getNrRows();
    final double initialValue = actualTarget.getMedian();
    final ArrayList<TreeModelRegression> models = new ArrayList<TreeModelRegression>(nrModels);
    final ArrayList<Map<TreeNodeSignature, Double>> coefficientMaps = new ArrayList<Map<TreeNodeSignature, Double>>(nrModels);
    final double[] previousPrediction = new double[nrRows];
    Arrays.fill(previousPrediction, initialValue);
    final RandomData rd = config.createRandomData();
    final double alpha = config.getAlpha();
    final int maxLevels = config.getMaxLevels();
    final TreeNodeSignatureFactory signatureFactory;
    // a finite maximum depth is expected to be the common case
    if (maxLevels < TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE) {
        // NOTE(review): IntMath.pow silently overflows for maxLevels > 32 (yields Integer.MIN_VALUE or 0);
        // confirm that the configuration bounds maxLevels accordingly
        final int capacity = IntMath.pow(2, maxLevels - 1);
        signatureFactory = new TreeNodeSignatureFactory(capacity);
    } else {
        signatureFactory = new TreeNodeSignatureFactory();
    }
    exec.setMessage("Learning model");
    for (int i = 0; i < nrModels; i++) {
        final double[] residuals = new double[nrRows];
        for (int j = 0; j < nrRows; j++) {
            residuals[j] = actualTarget.getValueFor(j) - previousPrediction[j];
        }
        final double quantile = calculateAlphaQuantile(residuals, alpha);
        // residuals within [-quantile, quantile] are kept; larger ones are clipped to +/- quantile
        final double[] gradients = new double[nrRows];
        for (int j = 0; j < nrRows; j++) {
            gradients[j] = Math.abs(residuals[j]) <= quantile ? residuals[j] : quantile * Math.signum(residuals[j]);
        }
        final TreeData residualData = createResidualDataFromArray(gradients, actualData);
        final RandomData rdSingle = TreeEnsembleLearnerConfiguration.createRandomData(rd.nextLong(Long.MIN_VALUE, Long.MAX_VALUE));
        final RowSample rowSample = getRowSampler().createRowSample(rdSingle);
        final TreeLearnerRegression treeLearner = new TreeLearnerRegression(config, residualData, getIndexManager(), signatureFactory, rdSingle, rowSample);
        final TreeModelRegression tree = treeLearner.learnSingleTree(exec, rdSingle);
        final Map<TreeNodeSignature, Double> coefficientMap = calcCoefficientMap(residuals, quantile, tree);
        adaptPreviousPrediction(previousPrediction, tree, coefficientMap);
        models.add(tree);
        coefficientMaps.add(coefficientMap);
        // i is zero-based, so i + 1 models are complete at this point
        exec.setProgress((double) (i + 1) / nrModels, "Finished level " + (i + 1) + "/" + nrModels);
    }
    return new GradientBoostedTreesModel(config, actualData.getMetaData(), models.toArray(new TreeModelRegression[models.size()]), actualData.getTreeType(), initialValue, coefficientMaps);
}
Also used : RandomData(org.apache.commons.math.random.RandomData) ArrayList(java.util.ArrayList) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) GradientBoostedTreesModel(org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) GradientBoostingLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeLearnerRegression(org.knime.base.node.mine.treeensemble2.learner.TreeLearnerRegression) RowSample(org.knime.base.node.mine.treeensemble2.sample.row.RowSample) HashMap(java.util.HashMap) Map(java.util.Map) TreeNodeSignatureFactory(org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory)

Example 4 with TreeNodeSignatureFactory

use of org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory in project knime-core by knime.

From method testGetColumnSampleForTreeNodeTest of class AllColumnSampleStrategyTest:

/**
 * Tests {@link AllColumnSampleStrategy#getColumnSampleForTreeNode(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)}.
 * This also exercises {@link AllColumnSample}.
 * <p>
 * Both the root node and its first child (branch 0) must receive a sample holding all {@code TREE_DATA_SIZE}
 * columns, with the indices {@code 0 .. TREE_DATA_SIZE - 1} in ascending order.
 *
 * @throws Exception propagated from the test fixtures
 */
@Test
public void testGetColumnSampleForTreeNodeTest() throws Exception {
    // expected column indices: 0, 1, ..., TREE_DATA_SIZE - 1
    final int[] expectedIndices = new int[TREE_DATA_SIZE];
    int col = 0;
    while (col < TREE_DATA_SIZE) {
        expectedIndices[col] = col;
        col++;
    }

    final AllColumnSampleStrategy strategy = new AllColumnSampleStrategy(createTreeData());
    final TreeNodeSignatureFactory factory = createSignatureFactory();
    final TreeNodeSignature root = factory.getRootSignature();

    final ColumnSample rootSample = strategy.getColumnSampleForTreeNode(root);
    assertEquals("Wrong number of columns in sample.", TREE_DATA_SIZE, rootSample.getNumCols());
    assertArrayEquals(expectedIndices, rootSample.getColumnIndices());

    final ColumnSample childSample = strategy.getColumnSampleForTreeNode(factory.getChildSignatureFor(root, (byte) 0));
    assertEquals("Wrong number of columns in sample.", TREE_DATA_SIZE, childSample.getNumCols());
    assertArrayEquals(expectedIndices, childSample.getColumnIndices());
}
Also used : TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeSignatureFactory(org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory) Test(org.junit.Test)

Example 5 with TreeNodeSignatureFactory

use of org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory in project knime-core by knime.

From method testGetColumnSampleForTreeNode of class RFSubsetColumnSampleStrategyTest:

/**
 * Tests {@link RFSubsetColumnSampleStrategy#getColumnSampleForTreeNode(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)}.
 * <p>
 * The root node and both of its children must each receive a sample of 5 columns. The three samples must not all
 * be identical; since they are drawn randomly, identical samples are considered very unlikely.
 *
 * @throws Exception propagated from the test fixtures
 */
@Test
public void testGetColumnSampleForTreeNode() throws Exception {
    final RFSubsetColumnSampleStrategy strategy = new RFSubsetColumnSampleStrategy(createTreeData(), RD, 5);
    final TreeNodeSignatureFactory factory = createSignatureFactory();
    final TreeNodeSignature root = factory.getRootSignature();

    final ColumnSample rootSample = strategy.getColumnSampleForTreeNode(root);
    assertEquals("Wrong number of columns in sample.", 5, rootSample.getNumCols());
    final int[] rootCols = rootSample.getColumnIndices();

    final ColumnSample firstChildSample = strategy.getColumnSampleForTreeNode(factory.getChildSignatureFor(root, (byte) 0));
    assertEquals("Wrong number of columns in sample.", 5, firstChildSample.getNumCols());
    final int[] firstChildCols = firstChildSample.getColumnIndices();

    final ColumnSample secondChildSample = strategy.getColumnSampleForTreeNode(factory.getChildSignatureFor(root, (byte) 1));
    assertEquals("Wrong number of columns in sample.", 5, secondChildSample.getNumCols());
    final int[] secondChildCols = secondChildSample.getColumnIndices();

    assertEquals("sample sizes differ.", rootCols.length, firstChildCols.length);
    assertEquals("sample sizes differ.", rootCols.length, secondChildCols.length);
    assertEquals("sample sizes differ.", firstChildCols.length, secondChildCols.length);

    // scan position by position and stop at the first position where the three samples disagree
    boolean allIdentical = true;
    for (int pos = 0; allIdentical && pos < rootCols.length; pos++) {
        allIdentical = rootCols[pos] == firstChildCols[pos] && rootCols[pos] == secondChildCols[pos];
    }
    assertFalse("It is very unlikely that we get 3 times the same column sample.", allIdentical);
}
Also used : TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeSignatureFactory(org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory) Test(org.junit.Test)

Aggregations

- TreeNodeSignatureFactory (org.knime.base.node.mine.treeensemble2.learner.TreeNodeSignatureFactory): 6
- TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature): 5
- RandomData (org.apache.commons.math.random.RandomData): 3
- Test (org.junit.Test): 3
- TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData): 3
- TreeModelRegression (org.knime.base.node.mine.treeensemble2.model.TreeModelRegression): 3
- ArrayList (java.util.ArrayList): 2
- HashMap (java.util.HashMap): 2
- Map (java.util.Map): 2
- TreeLearnerRegression (org.knime.base.node.mine.treeensemble2.learner.TreeLearnerRegression): 2
- RowSample (org.knime.base.node.mine.treeensemble2.sample.row.RowSample): 2
- ExecutionMonitor (org.knime.core.node.ExecutionMonitor): 2
- Future (java.util.concurrent.Future): 1
- Semaphore (java.util.concurrent.Semaphore): 1
- AtomicReference (java.util.concurrent.atomic.AtomicReference): 1
- NominalValueRepresentation (org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation): 1
- TreeDataCreator (org.knime.base.node.mine.treeensemble2.data.TreeDataCreator): 1
- TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData): 1
- TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData): 1
- BitVectorDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.BitVectorDataIndexManager): 1