use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject in project knime-core by knime.
the class Proximity method calcProximities.
/**
 * Computes a proximity matrix for one or two tables by running every tree of the given ensemble.
 * <p>
 * With a single table a {@code SingleTableProximityMatrix} is used as long as the table has at most
 * 65500 rows; larger single tables fall back to a {@code TwoTablesProximityMatrix} of the table with
 * itself. With two tables a {@code TwoTablesProximityMatrix} is always used.
 * <p>
 * One task per tree is enqueued on the global KNIME thread pool. A semaphore with
 * {@code 3 * availableProcessors / 2} permits is acquired before each enqueue and handed to the task
 * (the tasks are expected to release their permit when done - their code is not visible here).
 * Finally the matrix is normalized with {@code 1.0 / nrTrees}.
 *
 * @param tables one or two input tables
 * @param modelPortObject the tree ensemble whose trees are evaluated
 * @param exec execution context used for progress reporting and sub-progress creation
 * @return the normalized proximity matrix
 * @throws IllegalArgumentException if {@code tables} does not contain exactly one or two tables
 * @throws InvalidSettingsException propagated from the calls made in this method
 * @throws InterruptedException if interrupted while waiting for a semaphore permit
 * @throws ExecutionException propagated from the calls made in this method
 * @throws CanceledExecutionException propagated from the calls made in this method
 */
public static ProximityMatrix calcProximities(final BufferedDataTable[] tables, final TreeEnsembleModelPortObject modelPortObject, final ExecutionContext exec) throws InvalidSettingsException, InterruptedException, ExecutionException, CanceledExecutionException {
    final int tableCount = tables.length;
    if (tableCount != 1 && tableCount != 2) {
        throw new IllegalArgumentException("Currently only up to two tables are supported.");
    }
    final boolean hasSecondTable = tableCount == 2;

    final ProximityMatrix matrix;
    if (hasSecondTable) {
        matrix = new TwoTablesProximityMatrix(tables[0], tables[1]);
    } else if (tables[0].size() <= 65500) {
        matrix = new SingleTableProximityMatrix(tables[0]);
    } else {
        // this is unfortunate and we should maybe think of a different solution
        matrix = new TwoTablesProximityMatrix(tables[0], tables[0]);
    }

    final TreeEnsembleModelPortObjectSpec spec = modelPortObject.getSpec();
    final TreeEnsembleModel ensemble = modelPortObject.getEnsembleModel();

    // One index array per input table, in table order.
    final int[][] learnColIndices = new int[tableCount][];
    for (int t = 0; t < tableCount; t++) {
        learnColIndices[t] = spec.calculateFilterIndices(tables[t].getDataTableSpec());
    }

    final ThreadPool pool = KNIMEConstants.GLOBAL_THREAD_POOL;
    final int permits = 3 * Runtime.getRuntime().availableProcessors() / 2;
    final Semaphore limiter = new Semaphore(permits);
    final AtomicReference<Throwable> firstFailure = new AtomicReference<>();
    final int nrTrees = ensemble.getNrModels();
    final Future<?>[] pending = new Future<?>[nrTrees];

    exec.setProgress(0, "Starting proximity calculation per tree.");
    for (int treeIdx = 0; treeIdx < nrTrees; treeIdx++) {
        // Block until a slot is free, then report progress and bail out early on a recorded failure.
        limiter.acquire();
        finishedTree(treeIdx, exec, nrTrees);
        checkThrowable(firstFailure);
        final AbstractTreeModel tree = ensemble.getTreeModel(treeIdx);
        final ExecutionMonitor treeProgress = exec.createSubProgress(0.0);
        if (hasSecondTable) {
            pending[treeIdx] = pool.enqueue(new TwoTablesProximityCalcRunnable(matrix, tables, learnColIndices, tree, modelPortObject, limiter, firstFailure, treeProgress));
        } else {
            pending[treeIdx] = pool.enqueue(new SingleTableProximityCalcRunnable(matrix, tables, learnColIndices, tree, modelPortObject, limiter, firstFailure, treeProgress));
        }
    }

    // Taking every permit once more waits until all in-flight tasks have given theirs back.
    for (int slot = 0; slot < permits; slot++) {
        limiter.acquire();
        finishedTree(nrTrees - permits + slot, exec, nrTrees);
    }

    // Collect failures surfaced by the futures; only the first recorded throwable is kept.
    for (int f = 0; f < pending.length; f++) {
        try {
            pending[f].get();
        } catch (Exception e) {
            firstFailure.compareAndSet(null, e);
        }
    }
    checkThrowable(firstFailure);

    matrix.normalize(1.0 / nrTrees);
    return matrix;
}
use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject in project knime-core by knime.
the class RandomForestNearestNeighborNodeModel method execute.
/**
 * {@inheritDoc}
 * <p>
 * Expects the tree ensemble model at {@code inObjects[0]}, the mandatory table at
 * {@code inObjects[1]} and an optional second table at {@code inObjects[2]} (may be {@code null}).
 * The proximity matrix is computed exactly once, using the measure selected in the settings, and
 * the {@code k} nearest neighbors are then read from it.
 *
 * @throws IllegalStateException if the configured proximity measure is not handled here
 */
@Override
protected BufferedDataTable[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    TreeEnsembleModelPortObject ensembleModel = (TreeEnsembleModelPortObject) inObjects[0];
    boolean optionalTable = inObjects[2] != null;
    BufferedDataTable[] tables = new BufferedDataTable[optionalTable ? 2 : 1];
    tables[0] = (BufferedDataTable) inObjects[1];
    if (optionalTable) {
        tables[1] = (BufferedDataTable) inObjects[2];
    }
    // 60% of the progress goes to the proximity computation, 40% to the neighbor search.
    ExecutionContext proxExec = exec.createSubExecutionContext(0.6);
    ExecutionContext nnExec = exec.createSubExecutionContext(0.4);
    exec.setMessage("Calculating");
    ProximityMeasure proximityMeasure = ProximityMeasure.valueOf(m_proximityMeasure.getStringValue());
    // The matrix is assigned only inside the switch: computing it up front as well would run the
    // expensive proximity calculation a second time (or throw its result away for PathProximity)
    // and report progress into proxExec twice.
    final ProximityMatrix proximityMatrix;
    switch (proximityMeasure) {
        case Proximity:
            proximityMatrix = Proximity.calcProximities(tables, ensembleModel, proxExec);
            break;
        case PathProximity:
            proximityMatrix = new PathProximity(tables, ensembleModel).calculatePathProximities(proxExec);
            break;
        default:
            throw new IllegalStateException("Encountered unknown proximity measure.");
    }
    exec.setMessage("Calculating nearest neighbors");
    int k = m_numNearestNeighbors.getIntValue();
    return proximityMatrix.getNearestNeighbors(nnExec, k);
}
use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject in project knime-core by knime.
the class RandomForestClassificationLearnerNodeModel method createOutOfBagPredictor.
/**
 * Builds the predictor used for the out-of-bag output of this learner.
 * <p>
 * The prediction column is named {@code "<target column> (Out-of-bag)"}, and the prediction
 * confidence, the class confidences and the model count are all appended.
 *
 * @param ensembleSpec spec of the learned ensemble, passed on to the predictor
 * @param ensembleModel the learned ensemble, passed on to the predictor
 * @param inSpec spec of the table the predictor will be applied to
 * @return a predictor configured as described above
 * @throws InvalidSettingsException propagated from the configuration or predictor calls made here
 */
private TreeEnsemblePredictor createOutOfBagPredictor(final TreeEnsembleModelPortObjectSpec ensembleSpec, final TreeEnsembleModelPortObject ensembleModel, final DataTableSpec inSpec) throws InvalidSettingsException {
    final String target = m_configuration.getTargetColumn();
    final TreeEnsemblePredictorConfiguration predictorConfig = new TreeEnsemblePredictorConfiguration(false, target);
    predictorConfig.setPredictionColumnName(target + " (Out-of-bag)");
    predictorConfig.setAppendPredictionConfidence(true);
    predictorConfig.setAppendClassConfidences(true);
    predictorConfig.setAppendModelCount(true);
    return new TreeEnsemblePredictor(ensembleSpec, ensembleModel, inSpec, predictorConfig);
}
use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject in project knime-core by knime.
the class RandomForestClassificationLearnerNodeModel method execute.
/**
 * {@inheritDoc}
 * <p>
 * Learns a tree ensemble from the table at {@code inObjects[0]} and returns, in this order, the
 * out-of-bag prediction table, the column statistics table and the model port object. Progress is
 * split into 10% reading the data, 80% learning and 10% out-of-bag prediction.
 *
 * @throws InvalidSettingsException if the target column has no possible values assigned
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final BufferedDataTable inputTable = (BufferedDataTable) inObjects[0];
    final DataTableSpec inputSpec = inputTable.getDataTableSpec();

    // Reduce the input to the columns used for learning.
    final FilterLearnColumnRearranger columnFilter = m_configuration.filterLearnColumns(inputSpec);
    String warning = columnFilter.getWarning();
    final BufferedDataTable learnTable = exec.createColumnRearrangeTable(inputTable, columnFilter, exec.createSubProgress(0.0));
    final DataTableSpec learnSpec = learnTable.getDataTableSpec();
    final TreeEnsembleModelPortObjectSpec ensembleSpec = m_configuration.createPortObjectSpec(learnSpec);

    final Map<String, DataCell> possibleTargetValues = ensembleSpec.getTargetColumnPossibleValueMap();
    if (possibleTargetValues == null) {
        throw new InvalidSettingsException("The target column does not " + "have possible values assigned. Most likely it " + "has too many different distinct values (learning an ID " + "column?) Fix it by preprocessing the table using " + "a \"Domain Calculator\".");
    }

    final ExecutionMonitor readProgress = exec.createSubProgress(0.1);
    final ExecutionMonitor learnProgress = exec.createSubProgress(0.8);
    final ExecutionMonitor outOfBagProgress = exec.createSubProgress(0.1);

    final TreeDataCreator creator = new TreeDataCreator(m_configuration, learnSpec, learnTable.getRowCount());
    exec.setProgress("Reading data into memory");
    final TreeData treeData = creator.readData(learnTable, m_configuration, readProgress);
    m_hiliteRowSample = creator.getDataRowsForHilite();
    m_viewMessage = creator.getViewMessage();

    // Append the data-creation warning (if any) to the column-filter warning.
    final String creationWarning = creator.getAndClearWarningMessage();
    if (creationWarning != null) {
        warning = (warning == null) ? creationWarning : warning + "\n" + creationWarning;
    }
    readProgress.setProgress(1.0);

    exec.setMessage("Learning trees");
    // Use xgboost missing value handling
    m_configuration.setMissingValueHandling(MissingValueHandling.XGBoost);
    final TreeEnsembleLearner ensembleLearner = new TreeEnsembleLearner(m_configuration, treeData);
    final TreeEnsembleModel learnedModel;
    try {
        learnedModel = ensembleLearner.learnEnsemble(learnProgress);
    } catch (ExecutionException wrapped) {
        // Surface the underlying exception when there is one; otherwise rethrow the wrapper.
        final Throwable root = wrapped.getCause();
        throw (root instanceof Exception) ? (Exception) root : wrapped;
    }
    final TreeEnsembleModelPortObject portObject = TreeEnsembleModelPortObject.createPortObject(ensembleSpec, learnedModel, exec.createFileStore("TreeEnsemble"));
    learnProgress.setProgress(1.0);

    exec.setMessage("Out of bag prediction");
    final TreeEnsemblePredictor oobPredictor = createOutOfBagPredictor(ensembleSpec, portObject, inputSpec);
    oobPredictor.setOutofBagFilter(ensembleLearner.getRowSamples(), treeData.getTargetColumn());
    final ColumnRearranger oobRearranger = oobPredictor.getPredictionRearranger();
    final BufferedDataTable oobTable = exec.createColumnRearrangeTable(inputTable, oobRearranger, outOfBagProgress);
    final BufferedDataTable statsTable = ensembleLearner.createColumnStatisticTable(exec.createSubExecutionContext(0.0));

    m_ensembleModelPortObject = portObject;
    if (warning != null) {
        setWarningMessage(warning);
    }
    return new PortObject[] { oobTable, statsTable, portObject };
}
use of org.knime.base.node.mine.treeensemble2.model.TreeEnsembleModelPortObject in project knime-core by knime.
the class RandomForestClassificationPredictorNodeModel method execute.
/**
 * {@inheritDoc}
 * <p>
 * Applies the tree ensemble at {@code inObjects[0]} to the table at {@code inObjects[1]} and
 * returns the single table produced by the predictor's column rearranger. If the soft-voting
 * check of the configuration yields a message, it is set as the node warning.
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final TreeEnsembleModelPortObject ensemble = (TreeEnsembleModelPortObject) inObjects[0];
    final TreeEnsembleModelPortObjectSpec ensembleSpec = ensemble.getSpec();
    final BufferedDataTable input = (BufferedDataTable) inObjects[1];
    final DataTableSpec inputSpec = input.getDataTableSpec();

    m_configuration.checkSoftVotingSettingForModel(ensemble).ifPresent(this::setWarningMessage);

    final TreeEnsemblePredictor predictor = new TreeEnsemblePredictor(ensembleSpec, ensemble, inputSpec, m_configuration);
    final ColumnRearranger predictionRearranger = predictor.getPredictionRearranger();
    final BufferedDataTable predicted = exec.createColumnRearrangeTable(input, predictionRearranger, exec);
    return new BufferedDataTable[] { predicted };
}
Aggregations