Search in sources :

Example 1 with AbstractTreeModel

use of org.knime.base.node.mine.treeensemble2.model.AbstractTreeModel in project knime-core by knime.

the class RandomForestDistance method computeDistance.

/**
 * {@inheritDoc}
 * <p>
 * Both rows are reduced to the columns returned by {@code getColumnIndices()} and converted into
 * predictor records. The distance is one minus the fraction of trees in the ensemble for which
 * both records end up in nodes with equal signatures.
 */
@Override
public double computeDistance(final DataRow row1, final DataRow row2) throws DistanceMeasurementException {
    // unbox the configured column indices into the primitive array FilterColumnRow expects
    final List<Integer> columnIndexList = getColumnIndices();
    final int[] columnIndices = new int[columnIndexList.size()];
    int pos = 0;
    for (Integer columnIndex : columnIndexList) {
        columnIndices[pos] = columnIndex;
        pos++;
    }

    final PredictorRecord record1 =
        m_ensembleModel.createPredictorRecord(new FilterColumnRow(row1, columnIndices), m_learnTableSpec);
    final PredictorRecord record2 =
        m_ensembleModel.createPredictorRecord(new FilterColumnRow(row2, columnIndices), m_learnTableSpec);

    // count the trees in which both records are matched to nodes with the same signature
    final int nrModels = m_ensembleModel.getNrModels();
    int agreeingTrees = 0;
    for (int treeIdx = 0; treeIdx < nrModels; treeIdx++) {
        final AbstractTreeModel<?> tree = m_ensembleModel.getTreeModel(treeIdx);
        final AbstractTreeNode node1 = tree.findMatchingNode(record1);
        final AbstractTreeNode node2 = tree.findMatchingNode(record2);
        if (node1.getSignature().equals(node2.getSignature())) {
            agreeingTrees++;
        }
    }

    // proximity is the share of agreeing trees; the distance is its complement
    final double proximity = ((double) agreeingTrees) / nrModels;
    return 1 - proximity;
}
Also used : PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) AbstractTreeNode(org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode) DataRow(org.knime.core.data.DataRow) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow)

Example 2 with AbstractTreeModel

use of org.knime.base.node.mine.treeensemble2.model.AbstractTreeModel in project knime-core by knime.

the class TreeEnsembleModel method load.

/**
 * Reads a tree ensemble model from the given stream.
 * <p>
 * The stream is read in this order: format version, ensemble type byte (only for version
 * 20160114, otherwise {@code 'r'} is assumed), tree type, meta data, number of models, the
 * class-distribution flag (implied {@code true} for version 20121019), the individual tree models
 * (each followed by a 0 byte) and finally the ensemble specific data via {@code loadData}.
 *
 * @param in the stream to read from; it is wrapped in a {@link NonClosableInputStream}, so
 *            closing the internal wrapper is not meant to close this stream
 * @return the loaded ensemble model; the concrete class depends on the ensemble type byte
 * @throws IOException if the version is newer than 20160114, a tree model cannot be read or is
 *             not terminated by a 0 byte, or the underlying stream fails
 * @throws IllegalStateException if the ensemble type byte is not one of 'r', 'g', 't', 'm'
 */
public static TreeEnsembleModel load(final InputStream in) throws IOException {
    // buffering the (zip input) argument stream reduces read time from, e.g. 42s to 2s
    final TreeModelDataInputStream input =
        new TreeModelDataInputStream(new BufferedInputStream(new NonClosableInputStream(in)));

    final int version = input.readInt();
    if (version > 20160114) {
        throw new IOException("Tree Ensemble version " + version + " not supported");
    }
    // only the 20160114 format stores the ensemble type; older files are treated as type 'r'
    final byte ensembleType = (version == 20160114) ? input.readByte() : (byte) 'r';

    final TreeType type = TreeType.load(input);
    final TreeMetaData metaData = TreeMetaData.load(input);
    final int nrModels = input.readInt();
    // version 20121019 has no flag in the stream; there the class distribution is always present
    final boolean containsClassDistribution = (version == 20121019) || input.readBoolean();
    input.setContainsClassDistribution(containsClassDistribution);

    final AbstractTreeModel[] models = new AbstractTreeModel[nrModels];
    // every ensemble type other than 'r' is loaded with regression trees
    final boolean isRegression = metaData.isRegression() || ensembleType != 'r';
    final TreeBuildingInterner treeBuildingInterner = new TreeBuildingInterner();
    for (int modelIdx = 0; modelIdx < nrModels; modelIdx++) {
        try {
            if (isRegression) {
                models[modelIdx] = TreeModelRegression.load(input, metaData, treeBuildingInterner);
            } else {
                models[modelIdx] = TreeModelClassification.load(input, metaData, treeBuildingInterner);
            }
            if (input.readByte() != 0) {
                throw new IOException("Model not terminated by 0 byte");
            }
        } catch (IOException e) {
            throw new IOException(
                "Can't read tree model " + (modelIdx + 1) + "/" + nrModels + ": " + e.getMessage(), e);
        }
    }

    final TreeEnsembleModel ensemble;
    switch (ensembleType) {
        case 'r':
            ensemble = new TreeEnsembleModel(metaData, models, type, containsClassDistribution);
            break;
        case 'g':
            ensemble = new GradientBoostingModel(metaData, models, type, containsClassDistribution);
            break;
        case 't':
            ensemble = new GradientBoostedTreesModel(metaData, models, type, containsClassDistribution);
            break;
        case 'm':
            ensemble = new MultiClassGradientBoostedTreesModel(metaData, models, type, containsClassDistribution);
            break;
        default:
            throw new IllegalStateException("Unknown ensemble type: '" + (char) ensembleType + "'");
    }
    ensemble.loadData(input);
    // closes only the internal wrapper, not the method argument stream
    input.close();
    return ensemble;
}
Also used : IOException(java.io.IOException) BufferedInputStream(java.io.BufferedInputStream) NonClosableInputStream(org.knime.core.data.util.NonClosableInputStream) TreeMetaData(org.knime.base.node.mine.treeensemble2.data.TreeMetaData)

Example 3 with AbstractTreeModel

use of org.knime.base.node.mine.treeensemble2.model.AbstractTreeModel in project knime-core by knime.

the class TreeModelImporter method importFromPMML.

/**
 * Converts a PMML tree model into an {@link AbstractTreeModel}.
 * <p>
 * The root node of the PMML model is converted via {@code createNodeFromPMML}, starting with the
 * root signature of the signature factory, and the result is handed to the tree factory.
 *
 * @param treeModel the PMML tree model to convert
 * @return the {@link AbstractTreeModel} created by the tree factory for the converted root node
 */
public M importFromPMML(final TreeModel treeModel) {
    final N convertedRoot = createNodeFromPMML(treeModel.getNode(), m_signatureFactory.getRootSignature());
    return m_treeFactory.createTree(convertedRoot);
}
Also used : AbstractTreeNode(org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode) Node(org.dmg.pmml.NodeDocument.Node)

Example 4 with AbstractTreeModel

use of org.knime.base.node.mine.treeensemble2.model.AbstractTreeModel in project knime-core by knime.

the class Proximity method calcProximities.

/**
 * Calculates a proximity matrix for one or two tables using the trees of the given ensemble.
 * <p>
 * One calculation runnable per tree is enqueued on {@link KNIMEConstants#GLOBAL_THREAD_POOL}; a
 * semaphore with {@code 3 * availableProcessors / 2} permits is acquired before each submission.
 * After all runnables have been submitted and waited for, the matrix is normalized with
 * {@code 1.0 / nrTrees}.
 *
 * @param tables one or two tables; with two tables the two-table matrix and runnable are used
 * @param modelPortObject provides the ensemble model and the spec used to compute the filter
 *            indices of the learn columns in each table
 * @param exec used for progress reporting and for creating sub progress monitors per tree
 * @return the normalized proximity matrix
 * @throws IllegalArgumentException if {@code tables} contains neither one nor two tables
 * @throws InvalidSettingsException declared; presumably raised by {@code calculateFilterIndices}
 *             - TODO confirm
 * @throws InterruptedException if the thread is interrupted while waiting for a semaphore permit
 * @throws ExecutionException declared; exceptions of {@code future.get()} are caught locally, so
 *             this presumably originates from {@code checkThrowable} - TODO confirm
 * @throws CanceledExecutionException declared; presumably raised by {@code finishedTree} or
 *             {@code checkThrowable} - TODO confirm
 */
public static ProximityMatrix calcProximities(final BufferedDataTable[] tables, final TreeEnsembleModelPortObject modelPortObject, final ExecutionContext exec) throws InvalidSettingsException, InterruptedException, ExecutionException, CanceledExecutionException {
    ProximityMatrix proximityMatrix = null;
    // true only when a second table was provided
    boolean optionalTable = false;
    switch(tables.length) {
        case 1:
            // NOTE(review): 65500 presumably reflects a size limit of SingleTableProximityMatrix - confirm
            if (tables[0].size() <= 65500) {
                proximityMatrix = new SingleTableProximityMatrix(tables[0]);
            } else {
                // this is unfortunate and we should maybe think of a different solution
                proximityMatrix = new TwoTablesProximityMatrix(tables[0], tables[0]);
            }
            break;
        case 2:
            optionalTable = true;
            proximityMatrix = new TwoTablesProximityMatrix(tables[0], tables[1]);
            break;
        default:
            throw new IllegalArgumentException("Currently only up to two tables are supported.");
    }
    final TreeEnsembleModelPortObjectSpec modelSpec = modelPortObject.getSpec();
    final TreeEnsembleModel ensembleModel = modelPortObject.getEnsembleModel();
    // one index array per input table, as computed by the model spec for that table's spec
    int[][] learnColIndicesInTables = null;
    if (optionalTable) {
        learnColIndicesInTables = new int[][] { modelSpec.calculateFilterIndices(tables[0].getDataTableSpec()), modelSpec.calculateFilterIndices(tables[1].getDataTableSpec()) };
    } else {
        learnColIndicesInTables = new int[][] { modelSpec.calculateFilterIndices(tables[0].getDataTableSpec()) };
    }
    final ThreadPool tp = KNIMEConstants.GLOBAL_THREAD_POOL;
    // 1.5 permits per available processor (integer arithmetic)
    final int procCount = 3 * Runtime.getRuntime().availableProcessors() / 2;
    // a permit is acquired here before each submission; the semaphore is passed to the runnables,
    // which presumably release their permit when done (not visible in this method) - TODO confirm
    final Semaphore semaphore = new Semaphore(procCount);
    // shared with the runnables and checked via checkThrowable below
    final AtomicReference<Throwable> proxThrowableRef = new AtomicReference<Throwable>();
    final int nrTrees = ensembleModel.getNrModels();
    final Future<?>[] calcFutures = new Future<?>[nrTrees];
    exec.setProgress(0, "Starting proximity calculation per tree.");
    for (int i = 0; i < nrTrees; i++) {
        // blocks while no permit is available
        semaphore.acquire();
        finishedTree(i, exec, nrTrees);
        checkThrowable(proxThrowableRef);
        AbstractTreeModel treeModel = ensembleModel.getTreeModel(i);
        ExecutionMonitor subExec = exec.createSubProgress(0.0);
        if (optionalTable) {
            calcFutures[i] = tp.enqueue(new TwoTablesProximityCalcRunnable(proximityMatrix, tables, learnColIndicesInTables, treeModel, modelPortObject, semaphore, proxThrowableRef, subExec));
        } else {
            calcFutures[i] = tp.enqueue(new SingleTableProximityCalcRunnable(proximityMatrix, tables, learnColIndicesInTables, treeModel, modelPortObject, semaphore, proxThrowableRef, subExec));
        }
    }
    // take back all procCount permits; assuming the runnables release their permits, this waits
    // for the outstanding calculations.
    // NOTE(review): nrTrees - procCount + i is negative when nrTrees < procCount - verify that
    // finishedTree tolerates such values.
    for (int i = 0; i < procCount; i++) {
        semaphore.acquire();
        finishedTree(nrTrees - procCount + i, exec, nrTrees);
    }
    for (Future<?> future : calcFutures) {
        try {
            future.get();
        } catch (Exception e) {
            // only the first recorded throwable is kept; later ones are dropped by compareAndSet
            proxThrowableRef.compareAndSet(null, e);
        }
    }
    checkThrowable(proxThrowableRef);
    // NOTE(review): with nrTrees == 0 the factor 1.0 / nrTrees is infinite - confirm callers
    // never pass an empty ensemble
    proximityMatrix.normalize(1.0 / nrTrees);
    return proximityMatrix;
}
Also used : TreeEnsembleModel(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModel) TreeEnsembleModelPortObjectSpec(org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObjectSpec) ThreadPool(org.knime.core.util.ThreadPool) AtomicReference(java.util.concurrent.atomic.AtomicReference) AbstractTreeModel(org.knime.base.node.mine.treeensemble2.model.AbstractTreeModel) Semaphore(java.util.concurrent.Semaphore) InvalidSettingsException(org.knime.core.node.InvalidSettingsException) CanceledExecutionException(org.knime.core.node.CanceledExecutionException) ExecutionException(java.util.concurrent.ExecutionException) Future(java.util.concurrent.Future) ExecutionMonitor(org.knime.core.node.ExecutionMonitor)

Example 5 with AbstractTreeModel

use of org.knime.base.node.mine.treeensemble2.model.AbstractTreeModel in project knime-core by knime.

the class TreeEnsembleLearner method createColumnStatisticTable.

/**
 * Creates a table with one row per attribute holding per-level usage counts over all trees of
 * the ensemble.
 * <p>
 * For each of the first {@code REPORT_LEVEL} tree levels the row contains how often the attribute
 * was the split attribute of a node on that level, followed by how often it was part of the
 * column sample of a node on that level.
 *
 * @param exec used to create the data container and to check for cancellation after each row
 * @return the table built from the collected counts
 * @throws CanceledExecutionException if the execution was canceled
 */
public BufferedDataTable createColumnStatisticTable(final ExecutionContext exec) throws CanceledExecutionException {
    final BufferedDataContainer container = exec.createDataContainer(getColumnStatisticTableSpec());
    final int treeCount = m_ensembleModel.getNrModels();
    final TreeAttributeColumnData[] attributes = m_data.getColumns();
    final int attributeCount = attributes.length;
    // [level][attribute]: number of nodes on the level that split on the attribute
    final int[][] splitCounts = new int[REPORT_LEVEL][attributeCount];
    // [level][attribute]: number of nodes on the level whose column sample contained the attribute
    final int[][] sampleCounts = new int[REPORT_LEVEL][attributeCount];

    for (int treeIdx = 0; treeIdx < treeCount; treeIdx++) {
        final AbstractTreeModel<?> tree = m_ensembleModel.getTreeModel(treeIdx);
        for (int level = 0; level < REPORT_LEVEL; level++) {
            for (AbstractTreeNode node : tree.getTreeNodes(level)) {
                final TreeNodeSignature signature = node.getSignature();
                final ColumnSampleStrategy strategy = m_columnSampleStrategies[treeIdx];
                final ColumnSample sample = strategy.getColumnSampleForTreeNode(signature);
                for (TreeAttributeColumnData sampledColumn : sample) {
                    sampleCounts[level][sampledColumn.getMetaData().getAttributeIndex()]++;
                }
                final int splitAttribute = node.getSplitAttributeIndex();
                // a negative split index is not counted
                if (splitAttribute >= 0) {
                    splitCounts[level][splitAttribute]++;
                }
            }
        }
    }

    for (int att = 0; att < attributeCount; att++) {
        final String attributeName = attributes[att].getMetaData().getAttributeName();
        // first REPORT_LEVEL cells: split counts; second REPORT_LEVEL cells: sample counts
        final int[] cells = new int[2 * REPORT_LEVEL];
        for (int level = 0; level < REPORT_LEVEL; level++) {
            cells[level] = splitCounts[level][att];
            cells[REPORT_LEVEL + level] = sampleCounts[level][att];
        }
        container.addRowToTable(new DefaultRow(attributeName, cells));
        exec.checkCanceled();
    }
    container.close();
    return container.getTable();
}
Also used : ColumnSampleStrategy(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSampleStrategy) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) BufferedDataContainer(org.knime.core.node.BufferedDataContainer) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) AbstractTreeNode(org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) DataRow(org.knime.core.data.DataRow) DefaultRow(org.knime.core.data.def.DefaultRow)

Aggregations

AbstractTreeNode (org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode)3 ExecutionException (java.util.concurrent.ExecutionException)2 Future (java.util.concurrent.Future)2 Semaphore (java.util.concurrent.Semaphore)2 AtomicReference (java.util.concurrent.atomic.AtomicReference)2 AbstractTreeModel (org.knime.base.node.mine.treeensemble2.model.AbstractTreeModel)2 TreeEnsembleModel (org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModel)2 DataRow (org.knime.core.data.DataRow)2 CanceledExecutionException (org.knime.core.node.CanceledExecutionException)2 ExecutionMonitor (org.knime.core.node.ExecutionMonitor)2 ThreadPool (org.knime.core.util.ThreadPool)2 BufferedInputStream (java.io.BufferedInputStream)1 IOException (java.io.IOException)1 Callable (java.util.concurrent.Callable)1 RandomData (org.apache.commons.math.random.RandomData)1 Node (org.dmg.pmml.NodeDocument.Node)1 FilterColumnRow (org.knime.base.data.filter.column.FilterColumnRow)1 PredictorRecord (org.knime.base.node.mine.treeensemble2.data.PredictorRecord)1 TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData)1 TreeMetaData (org.knime.base.node.mine.treeensemble2.data.TreeMetaData)1