Use of org.knime.base.node.mine.treeensemble2.node.predictor.TreeEnsemblePredictor in project knime-core by KNIME.
Class TreeEnsembleRegressionLearnerNodeModel, method execute.
/**
 * Learns a tree ensemble from the input table and computes out-of-bag predictions for it.
 * <p>
 * The overall progress is split into three sub-ranges: 10% for reading the data into memory, 80% for learning the
 * trees and 10% for the out-of-bag prediction.
 *
 * @param inObjects the input port objects; index 0 is the {@link BufferedDataTable} to learn from
 * @param exec the execution context used for progress reporting, table creation and the model file store
 * @return the out-of-bag prediction table, the column statistics table and the ensemble model port object
 * @throws Exception if learning fails; when the learner reports an {@link ExecutionException} whose cause is an
 *             {@link Exception}, that cause is rethrown instead
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final BufferedDataTable inputTable = (BufferedDataTable) inObjects[0];
    final DataTableSpec inputSpec = inputTable.getDataTableSpec();

    // restrict the input to the columns configured for learning; the filter may produce a warning
    final FilterLearnColumnRearranger learnRearranger = m_configuration.filterLearnColumns(inputSpec);
    String warning = learnRearranger.getWarning();
    final BufferedDataTable learnTable =
        exec.createColumnRearrangeTable(inputTable, learnRearranger, exec.createSubProgress(0.0));
    final DataTableSpec learnSpec = learnTable.getDataTableSpec();
    final TreeEnsembleModelPortObjectSpec ensembleSpec = m_configuration.createPortObjectSpec(learnSpec);

    final ExecutionMonitor readProgress = exec.createSubProgress(0.1);
    final ExecutionMonitor learnProgress = exec.createSubProgress(0.8);
    final ExecutionMonitor oobProgress = exec.createSubProgress(0.1);

    // phase 1: materialize the learning data
    final TreeDataCreator creator = new TreeDataCreator(m_configuration, learnSpec, learnTable.getRowCount());
    exec.setProgress("Reading data into memory");
    final TreeData treeData = creator.readData(learnTable, m_configuration, readProgress);
    m_hiliteRowSample = creator.getDataRowsForHilite();
    m_viewMessage = creator.getViewMessage();
    final String creatorWarning = creator.getAndClearWarningMessage();
    if (creatorWarning != null) {
        // append to any warning from the column filter, one message per line
        warning = (warning == null) ? creatorWarning : warning + "\n" + creatorWarning;
    }
    readProgress.setProgress(1.0);

    // phase 2: learn the ensemble; unwrap the cause of an ExecutionException when it is an Exception
    exec.setMessage("Learning trees");
    final TreeEnsembleLearner learner = new TreeEnsembleLearner(m_configuration, treeData);
    final TreeEnsembleModel ensemble;
    try {
        ensemble = learner.learnEnsemble(learnProgress);
    } catch (ExecutionException ee) {
        final Throwable cause = ee.getCause();
        if (!(cause instanceof Exception)) {
            throw ee;
        }
        throw (Exception) cause;
    }
    final TreeEnsembleModelPortObject modelPortObject = TreeEnsembleModelPortObject.createPortObject(ensembleSpec,
        ensemble, exec.createFileStore(UUID.randomUUID().toString()));
    learnProgress.setProgress(1.0);

    // phase 3: predict each row using only the trees whose row sample did not contain it
    exec.setMessage("Out of bag prediction");
    final TreeEnsemblePredictor oobPredictor = createOutOfBagPredictor(ensembleSpec, modelPortObject, inputSpec);
    oobPredictor.setOutofBagFilter(learner.getRowSamples(), treeData.getTargetColumn());
    final ColumnRearranger oobRearranger = oobPredictor.getPredictionRearranger();
    final BufferedDataTable oobTable = exec.createColumnRearrangeTable(inputTable, oobRearranger, oobProgress);
    final BufferedDataTable columnStatistics =
        learner.createColumnStatisticTable(exec.createSubExecutionContext(0.0));

    m_ensembleModelPortObject = modelPortObject;
    if (warning != null) {
        setWarningMessage(warning);
    }
    return new PortObject[] {oobTable, columnStatistics, modelPortObject};
}
Use of org.knime.base.node.mine.treeensemble2.node.predictor.TreeEnsemblePredictor in project knime-core by KNIME.
Class TreeEnsembleClassificationPredictorCellFactory, method createFactory.
/**
 * Creates a {@link TreeEnsembleClassificationPredictorCellFactory} from the provided <b>predictor</b>.
 * <p>
 * The appended columns are:
 * <ul>
 * <li>the prediction column;</li>
 * <li>optionally, a prediction confidence column;</li>
 * <li>optionally, one probability column per possible target value;</li>
 * <li>optionally, a model count column.</li>
 * </ul>
 * Column names are made unique against the test data spec.
 *
 * @param predictor the predictor whose data spec, model spec, model object and configuration determine the factory
 * @return an instance of TreeEnsembleClassificationPredictorCellFactory configured according to the settings of the
 *         provided <b>predictor</b>
 * @throws InvalidSettingsException if class confidences are requested but the target column has no possible values,
 *             if a model object is present but the target values are unknown, or if computing the learn column
 *             indices fails
 */
public static TreeEnsembleClassificationPredictorCellFactory createFactory(final TreeEnsemblePredictor predictor)
    throws InvalidSettingsException {
    final DataTableSpec testDataSpec = predictor.getDataSpec();
    final TreeEnsembleModelPortObjectSpec modelSpec = predictor.getModelSpec();
    final TreeEnsembleModelPortObject modelObject = predictor.getModelObject();
    final TreeEnsemblePredictorConfiguration configuration = predictor.getConfiguration();
    final UniqueNameGenerator nameGen = new UniqueNameGenerator(testDataSpec);
    // may be null while no model object is available (e.g. at configure time)
    // NOTE(review): confirm when getTargetColumnPossibleValueMap returns null
    final Map<String, DataCell> targetValueMap = modelSpec.getTargetColumnPossibleValueMap();

    final List<DataColumnSpec> newColsList = new ArrayList<DataColumnSpec>();
    final DataType targetColType = modelSpec.getTargetColumn().getType();
    final String predictionColName = configuration.getPredictionColumnName();
    final DataColumnSpec targetCol = nameGen.newColumn(predictionColName, targetColType);
    newColsList.add(targetCol);
    if (configuration.isAppendPredictionConfidence()) {
        newColsList.add(nameGen.newColumn(targetCol.getName() + " (Confidence)", DoubleCell.TYPE));
    }
    if (configuration.isAppendClassConfidences()) {
        // One column per possible value is required. Assertions may be disabled at runtime,
        // so report the problem as a settings error rather than an NPE.
        if (targetValueMap == null) {
            throw new InvalidSettingsException("Target column has no possible values");
        }
        final String targetColName = modelSpec.getTargetColumn().getName();
        final String suffix = configuration.getSuffixForClassProbabilities();
        for (String v : targetValueMap.keySet()) {
            final String colName = "P(" + targetColName + "=" + v + ")" + suffix;
            newColsList.add(nameGen.newColumn(colName, DoubleCell.TYPE));
        }
    }
    if (configuration.isAppendModelCount()) {
        newColsList.add(nameGen.newColumn("model count", IntCell.TYPE));
    }
    // once a model object is present (execution), the possible target values must be assigned
    if (modelObject != null && targetValueMap == null) {
        throw new InvalidSettingsException("Target values must be known during execution");
    }
    final DataColumnSpec[] newCols = newColsList.toArray(new DataColumnSpec[newColsList.size()]);
    final int[] learnColumnInRealDataIndices = modelSpec.calculateFilterIndices(testDataSpec);

    // map each possible target value to its position in the map's iteration order
    final Map<String, Integer> targetValueToIndexMap;
    if (targetValueMap == null) {
        // only reachable without a model object (checked above); previously this case threw an NPE
        targetValueToIndexMap = new HashMap<String, Integer>();
    } else {
        targetValueToIndexMap = new HashMap<String, Integer>(targetValueMap.size());
        int index = 0;
        for (String value : targetValueMap.keySet()) {
            targetValueToIndexMap.put(value, index++);
        }
    }

    final VotingFactory votingFactory;
    if (configuration.isUseSoftVoting()) {
        votingFactory = new SoftVotingFactory(targetValueToIndexMap);
    } else {
        votingFactory = new HardVotingFactory(targetValueToIndexMap);
    }
    return new TreeEnsembleClassificationPredictorCellFactory(predictor, targetValueMap, newCols,
        learnColumnInRealDataIndices, votingFactory);
}
Use of org.knime.base.node.mine.treeensemble2.node.predictor.TreeEnsemblePredictor in project knime-core by KNIME.
Class TreeEnsembleClassificationPredictorNodeModel, method configure.
/**
 * Computes the output table spec of the classification predictor.
 * <p>
 * Creates a default predictor configuration if none exists yet. Otherwise, unless the user chose a custom name,
 * resets the prediction column name to the default derived from the model's target column.
 *
 * @param inSpecs index 0: the {@link TreeEnsembleModelPortObjectSpec}; index 1: the {@link DataTableSpec} of the data
 *            to predict
 * @return a single-element array holding the output spec, or holding {@code null} if no rearranger was created
 * @throws InvalidSettingsException if the model's target type or the predictor settings are not applicable
 */
@Override
protected PortObjectSpec[] configure(final PortObjectSpec[] inSpecs) throws InvalidSettingsException {
    final TreeEnsembleModelPortObjectSpec modelSpec = (TreeEnsembleModelPortObjectSpec) inSpecs[0];
    final String targetName = modelSpec.getTargetColumn().getName();
    if (m_configuration == null) {
        m_configuration = TreeEnsemblePredictorConfiguration.createDefault(false, targetName);
    } else if (!m_configuration.isChangePredictionColumnName()) {
        final String defaultPredictionName = TreeEnsemblePredictorConfiguration.getPredictColumnName(targetName);
        m_configuration.setPredictionColumnName(defaultPredictionName);
    }
    modelSpec.assertTargetTypeMatches(false);

    final DataTableSpec testSpec = (DataTableSpec) inSpecs[1];
    // no model object exists at configure time, hence null
    final ColumnRearranger rearranger =
        new TreeEnsemblePredictor(modelSpec, null, testSpec, m_configuration).getPredictionRearranger();
    // the rearranger may be null if confidence values are appended but the model
    // does not have a list of possible target values
    if (rearranger == null) {
        return new DataTableSpec[] {null};
    }
    return new DataTableSpec[] {rearranger.createSpec()};
}
Use of org.knime.base.node.mine.treeensemble2.node.predictor.TreeEnsemblePredictor in project knime-core by KNIME.
Class TreeEnsembleClassificationPredictorNodeModel, method execute.
/**
 * Appends the ensemble's classification predictions to the input data.
 * <p>
 * Sets a warning message if the configuration reports a problem with the soft-voting setting for the given model.
 *
 * @param inObjects index 0: the {@link TreeEnsembleModelPortObject}; index 1: the {@link BufferedDataTable} to predict
 * @param exec the execution context used for progress reporting and table creation
 * @return a single-element array holding the input table with the prediction columns appended
 * @throws Exception if building the prediction rearranger or the output table fails
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final TreeEnsembleModelPortObject ensembleModel = (TreeEnsembleModelPortObject) inObjects[0];
    final TreeEnsembleModelPortObjectSpec ensembleSpec = ensembleModel.getSpec();
    final BufferedDataTable testTable = (BufferedDataTable) inObjects[1];
    final DataTableSpec testSpec = testTable.getDataTableSpec();
    m_configuration.checkSoftVotingSettingForModel(ensembleModel).ifPresent(this::setWarningMessage);

    final TreeEnsemblePredictor predictor =
        new TreeEnsemblePredictor(ensembleSpec, ensembleModel, testSpec, m_configuration);
    final ColumnRearranger predictionRearranger = predictor.getPredictionRearranger();
    final BufferedDataTable predictedTable = exec.createColumnRearrangeTable(testTable, predictionRearranger, exec);
    return new BufferedDataTable[] {predictedTable};
}
Use of org.knime.base.node.mine.treeensemble2.node.predictor.TreeEnsemblePredictor in project knime-core by KNIME.
Class TreeEnsembleRegressionPredictorNodeModel, method configure.
/**
 * Computes the output table spec of the regression predictor.
 * <p>
 * Creates a default predictor configuration if none exists yet. Otherwise, unless the user chose a custom name,
 * resets the prediction column name to the default derived from the model's target column.
 *
 * @param inSpecs index 0: the {@link TreeEnsembleModelPortObjectSpec}; index 1: the {@link DataTableSpec} of the data
 *            to predict
 * @return a single-element array holding the output spec, or holding {@code null} if no rearranger was created
 * @throws InvalidSettingsException if the model's target type or the predictor settings are not applicable
 */
@Override
protected PortObjectSpec[] configure(final PortObjectSpec[] inSpecs) throws InvalidSettingsException {
    final TreeEnsembleModelPortObjectSpec modelSpec = (TreeEnsembleModelPortObjectSpec) inSpecs[0];
    final String targetName = modelSpec.getTargetColumn().getName();
    if (m_configuration == null) {
        // NOTE(review): this regression node passes false here but true to assertTargetTypeMatches below;
        // confirm the meaning of createDefault's boolean argument
        m_configuration = TreeEnsemblePredictorConfiguration.createDefault(false, targetName);
    } else if (!m_configuration.isChangePredictionColumnName()) {
        final String defaultPredictionName = TreeEnsemblePredictorConfiguration.getPredictColumnName(targetName);
        m_configuration.setPredictionColumnName(defaultPredictionName);
    }
    modelSpec.assertTargetTypeMatches(true);

    final DataTableSpec testSpec = (DataTableSpec) inSpecs[1];
    // no model object exists at configure time, hence null
    final ColumnRearranger rearranger =
        new TreeEnsemblePredictor(modelSpec, null, testSpec, m_configuration).getPredictionRearranger();
    // the rearranger may be null if confidence values are appended but the model
    // does not have a list of possible target values
    if (rearranger == null) {
        return new DataTableSpec[] {null};
    }
    return new DataTableSpec[] {rearranger.createSpec()};
}
Aggregations