use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject in project knime-core by knime.
the class RandomForestRegressionLearnerNodeModel method execute.
/**
 * {@inheritDoc}
 *
 * <p>Learns a tree ensemble from the table at input port 0 and returns, in this order, the
 * out-of-bag prediction table, the column statistics table and the ensemble model port object.
 * A warning reported by the learn-column filter and a warning reported while reading the data
 * are joined with a newline and set as the node warning.</p>
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final BufferedDataTable inputTable = (BufferedDataTable) inObjects[0];
    final DataTableSpec inputSpec = inputTable.getDataTableSpec();

    // Reduce the input to the columns selected by the configuration for learning.
    final FilterLearnColumnRearranger columnFilter = m_configuration.filterLearnColumns(inputSpec);
    String warning = columnFilter.getWarning();
    final BufferedDataTable learnTable =
        exec.createColumnRearrangeTable(inputTable, columnFilter, exec.createSubProgress(0.0));
    final DataTableSpec learnSpec = learnTable.getDataTableSpec();
    final TreeEnsembleModelPortObjectSpec ensembleSpec = m_configuration.createPortObjectSpec(learnSpec);

    // Sub-progress fractions: 0.1 for reading, 0.8 for learning, 0.1 for out-of-bag prediction.
    final ExecutionMonitor readMonitor = exec.createSubProgress(0.1);
    final ExecutionMonitor trainMonitor = exec.createSubProgress(0.8);
    final ExecutionMonitor oobMonitor = exec.createSubProgress(0.1);

    final TreeDataCreator creator = new TreeDataCreator(m_configuration, learnSpec, learnTable.getRowCount());
    exec.setProgress("Reading data into memory");
    final TreeData treeData = creator.readData(learnTable, m_configuration, readMonitor);
    m_hiliteRowSample = creator.getDataRowsForHilite();
    m_viewMessage = creator.getViewMessage();
    final String readWarning = creator.getAndClearWarningMessage();
    if (readWarning != null) {
        warning = (warning == null) ? readWarning : warning + "\n" + readWarning;
    }
    readMonitor.setProgress(1.0);

    exec.setMessage("Learning trees");
    final TreeEnsembleLearner learner = new TreeEnsembleLearner(m_configuration, treeData);
    TreeEnsembleModel ensemble;
    try {
        ensemble = learner.learnEnsemble(trainMonitor);
    } catch (ExecutionException e) {
        // Surface the wrapped failure when it is an Exception; otherwise rethrow the wrapper itself.
        final Throwable cause = e.getCause();
        throw (cause instanceof Exception) ? (Exception) cause : e;
    }
    final TreeEnsembleModelPortObject modelPortObject =
        TreeEnsembleModelPortObject.createPortObject(ensembleSpec, ensemble, exec.createFileStore("TreeEnsemble"));
    trainMonitor.setProgress(1.0);

    exec.setMessage("Out of bag prediction");
    // The predictor is built against the spec of the full input table, not the filtered learn table.
    final TreeEnsemblePredictor oobPredictor = createOutOfBagPredictor(ensembleSpec, modelPortObject, inputSpec);
    oobPredictor.setOutofBagFilter(learner.getRowSamples(), treeData.getTargetColumn());
    final ColumnRearranger oobRearranger = oobPredictor.getPredictionRearranger();
    final BufferedDataTable oobTable = exec.createColumnRearrangeTable(inputTable, oobRearranger, oobMonitor);
    final BufferedDataTable statsTable = learner.createColumnStatisticTable(exec.createSubExecutionContext(0.0));

    // Keep a reference to the learned model on the node model.
    m_ensembleModelPortObject = modelPortObject;
    if (warning != null) {
        setWarningMessage(warning);
    }
    return new PortObject[] { oobTable, statsTable, modelPortObject };
}
use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject in project knime-core by knime.
the class RandomForestRegressionLearnerNodeModel method createOutOfBagPredictor.
/**
 * Builds the predictor used for the out-of-bag prediction. The prediction column is named
 * after the configured target column with the suffix {@code " (Out-of-bag)"}; prediction
 * confidence, class confidences and the model count are all switched on.
 *
 * @param ensembleSpec the spec of the learned ensemble model
 * @param ensembleModel the learned ensemble model port object
 * @param inSpec the spec of the table the predictor is created for
 * @return a predictor configured for out-of-bag prediction
 * @throws InvalidSettingsException if raised while configuring or constructing the predictor
 */
private TreeEnsemblePredictor createOutOfBagPredictor(final TreeEnsembleModelPortObjectSpec ensembleSpec, final TreeEnsembleModelPortObject ensembleModel, final DataTableSpec inSpec) throws InvalidSettingsException {
    final String target = m_configuration.getTargetColumn();
    // NOTE(review): the meaning of the leading 'true' flag is not visible here — confirm against
    // TreeEnsemblePredictorConfiguration.
    final TreeEnsemblePredictorConfiguration oobSettings = new TreeEnsemblePredictorConfiguration(true, target);
    oobSettings.setPredictionColumnName(target + " (Out-of-bag)");
    oobSettings.setAppendPredictionConfidence(true);
    oobSettings.setAppendClassConfidences(true);
    oobSettings.setAppendModelCount(true);
    return new TreeEnsemblePredictor(ensembleSpec, ensembleModel, inSpec, oobSettings);
}
use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject in project knime-core by knime.
the class RandomForestRegressionPredictorNodeModel method execute.
/**
 * {@inheritDoc}
 *
 * <p>Expects the ensemble model at input port 0 and the data table at input port 1, and returns
 * a single table produced by applying the predictor's column rearranger to the data table.</p>
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final TreeEnsembleModelPortObject ensemblePort = (TreeEnsembleModelPortObject) inObjects[0];
    final TreeEnsembleModelPortObjectSpec ensembleSpec = ensemblePort.getSpec();
    final BufferedDataTable inputTable = (BufferedDataTable) inObjects[1];
    final DataTableSpec inputSpec = inputTable.getDataTableSpec();

    final TreeEnsemblePredictor predictor =
        new TreeEnsemblePredictor(ensembleSpec, ensemblePort, inputSpec, m_configuration);
    final BufferedDataTable predicted =
        exec.createColumnRearrangeTable(inputTable, predictor.getPredictionRearranger(), exec);
    return new BufferedDataTable[] { predicted };
}
use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject in project knime-core by knime.
the class TreeEnsembleStatisticsNodeModel method execute.
/**
 * Produces two tables from the ensemble model at input port 0: a single-row table with
 * ensemble-wide statistics (number of models, min/max/average level, min/max/average number of
 * nodes) and a table with one row per tree holding that tree's number of levels and nodes.
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final TreeEnsembleModel ensemble = ((TreeEnsembleModelPortObject) inObjects[0]).getEnsembleModel();
    final EnsembleStatistic stats = new EnsembleStatistic(ensemble);

    // Table 1: one row summarising the whole ensemble.
    final DataContainer summaryContainer = exec.createDataContainer(createEnsembleStatsSpec());
    final DataCell[] summaryCells = new DataCell[] {
        new IntCell(ensemble.getNrModels()),
        new IntCell(stats.getMinLevel()),
        new IntCell(stats.getMaxLevel()),
        new DoubleCell(stats.getAvgLevel()),
        new IntCell(stats.getMinNumNodes()),
        new IntCell(stats.getMaxNumNodes()),
        new DoubleCell(stats.getAvgNumNodes())
    };
    summaryContainer.addRowToTable(new DefaultRow(RowKey.createRowKey(0L), summaryCells));
    summaryContainer.close();

    // Table 2: one row per tree, keyed by the tree index.
    final DataContainer perTreeContainer = exec.createDataContainer(createTreeStatsSpec());
    final int treeCount = ensemble.getNrModels();
    for (int treeIdx = 0; treeIdx < treeCount; treeIdx++) {
        final TreeStatistic treeStat = stats.getTreeStatistic(treeIdx);
        final DataCell[] rowCells = new DataCell[] {
            new IntCell(treeStat.getNumLevels()),
            new IntCell(treeStat.getNumNodes())
        };
        perTreeContainer.addRowToTable(new DefaultRow(RowKey.createRowKey((long) treeIdx), rowCells));
    }
    perTreeContainer.close();

    final PortObject summaryTable = (PortObject) summaryContainer.getTable();
    final PortObject perTreeTable = (PortObject) perTreeContainer.getTable();
    return new PortObject[] { summaryTable, perTreeTable };
}
use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject in project knime-core by knime.
the class TreeEnsembleClassificationLearnerNodeModel method execute.
/**
 * {@inheritDoc}
 *
 * <p>Learns a tree ensemble from the table at input port 0 and returns, in this order, the
 * out-of-bag prediction table, the column statistics table and the ensemble model port object.
 * Fails with an {@link InvalidSettingsException} when the ensemble spec provides no possible-value
 * map for the target column. A warning reported by the learn-column filter and a warning reported
 * while reading the data are joined with a newline and set as the node warning.</p>
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final BufferedDataTable inputTable = (BufferedDataTable) inObjects[0];
    final DataTableSpec inputSpec = inputTable.getDataTableSpec();

    // Reduce the input to the columns selected by the configuration for learning.
    final FilterLearnColumnRearranger columnFilter = m_configuration.filterLearnColumns(inputSpec);
    String warning = columnFilter.getWarning();
    final BufferedDataTable learnTable =
        exec.createColumnRearrangeTable(inputTable, columnFilter, exec.createSubProgress(0.0));
    final DataTableSpec learnSpec = learnTable.getDataTableSpec();
    final TreeEnsembleModelPortObjectSpec ensembleSpec = m_configuration.createPortObjectSpec(learnSpec);

    // Without a possible-value map for the target column the learner cannot proceed.
    if (ensembleSpec.getTargetColumnPossibleValueMap() == null) {
        throw new InvalidSettingsException("The target column does not "
            + "have possible values assigned. Most likely it "
            + "has too many different distinct values (learning an ID "
            + "column?) Fix it by preprocessing the table using "
            + "a \"Domain Calculator\".");
    }

    // Sub-progress fractions: 0.1 for reading, 0.8 for learning, 0.1 for out-of-bag prediction.
    final ExecutionMonitor readMonitor = exec.createSubProgress(0.1);
    final ExecutionMonitor trainMonitor = exec.createSubProgress(0.8);
    final ExecutionMonitor oobMonitor = exec.createSubProgress(0.1);

    final TreeDataCreator creator = new TreeDataCreator(m_configuration, learnSpec, learnTable.getRowCount());
    exec.setProgress("Reading data into memory");
    final TreeData treeData = creator.readData(learnTable, m_configuration, readMonitor);
    m_hiliteRowSample = creator.getDataRowsForHilite();
    m_viewMessage = creator.getViewMessage();
    final String readWarning = creator.getAndClearWarningMessage();
    if (readWarning != null) {
        warning = (warning == null) ? readWarning : warning + "\n" + readWarning;
    }
    readMonitor.setProgress(1.0);

    exec.setMessage("Learning trees");
    final TreeEnsembleLearner learner = new TreeEnsembleLearner(m_configuration, treeData);
    TreeEnsembleModel ensemble;
    try {
        ensemble = learner.learnEnsemble(trainMonitor);
    } catch (ExecutionException e) {
        // Surface the wrapped failure when it is an Exception; otherwise rethrow the wrapper itself.
        final Throwable cause = e.getCause();
        throw (cause instanceof Exception) ? (Exception) cause : e;
    }
    // The file store is named with a freshly generated random UUID.
    final TreeEnsembleModelPortObject modelPortObject = TreeEnsembleModelPortObject.createPortObject(
        ensembleSpec, ensemble, exec.createFileStore(UUID.randomUUID().toString()));
    trainMonitor.setProgress(1.0);

    exec.setMessage("Out of bag prediction");
    // The predictor is built against the spec of the full input table, not the filtered learn table.
    final TreeEnsemblePredictor oobPredictor = createOutOfBagPredictor(ensembleSpec, modelPortObject, inputSpec);
    oobPredictor.setOutofBagFilter(learner.getRowSamples(), treeData.getTargetColumn());
    final ColumnRearranger oobRearranger = oobPredictor.getPredictionRearranger();
    final BufferedDataTable oobTable = exec.createColumnRearrangeTable(inputTable, oobRearranger, oobMonitor);
    final BufferedDataTable statsTable = learner.createColumnStatisticTable(exec.createSubExecutionContext(0.0));

    // Keep a reference to the learned model on the node model.
    m_ensembleModelPortObject = modelPortObject;
    if (warning != null) {
        setWarningMessage(warning);
    }
    return new PortObject[] { oobTable, statsTable, modelPortObject };
}
Aggregations