Usage of org.knime.base.node.mine.treeensemble2.node.proximity.ProximityMatrix in project knime-core by KNIME.
Class RandomForestProximityNodeModel, method execute:
/**
 * Computes the proximity matrix for the connected input table(s) with the measure selected in the dialog and
 * writes it to a single output table.
 */
@Override
protected BufferedDataTable[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final TreeEnsembleModelPortObject ensemble = (TreeEnsembleModelPortObject) inObjects[0];
    final BufferedDataTable firstTable = (BufferedDataTable) inObjects[1];
    final BufferedDataTable secondTable = (BufferedDataTable) inObjects[2];
    // The second data port is optional: only hand it on when it is connected.
    final BufferedDataTable[] inputTables = (secondTable == null)
        ? new BufferedDataTable[] { firstTable }
        : new BufferedDataTable[] { firstTable, secondTable };
    // Sub-contexts weighted 0.7 (matrix computation) and 0.3 (writing the output table).
    final ExecutionContext proximityExec = exec.createSubExecutionContext(0.7);
    final ExecutionContext outputExec = exec.createSubExecutionContext(0.3);
    exec.setMessage("Calculating Proximity");
    final ProximityMeasure measure = ProximityMeasure.valueOf(m_proximityMeasure.getStringValue());
    final ProximityMatrix matrix;
    if (measure == ProximityMeasure.PathProximity) {
        matrix = new PathProximity(inputTables, ensemble).calculatePathProximities(proximityExec);
    } else if (measure == ProximityMeasure.Proximity) {
        matrix = Proximity.calcProximities(inputTables, ensemble, proximityExec);
    } else {
        throw new IllegalStateException("Illegal proximity measure encountered.");
    }
    exec.setMessage("Writing");
    return new BufferedDataTable[] { matrix.createTable(outputExec) };
}
Usage of org.knime.base.node.mine.treeensemble2.node.proximity.ProximityMatrix in project knime-core by KNIME.
Class Proximity, method calcProximities:
/**
 * Computes the (classic) random-forest proximity matrix for one or two tables.
 * <p>
 * One task per tree of the ensemble is enqueued on {@link KNIMEConstants#GLOBAL_THREAD_POOL}; each task receives the
 * shared matrix and the tree model. After all tasks have finished, the matrix is normalized by {@code 1 / nrTrees}.
 * <p>
 * With one table, a {@code SingleTableProximityMatrix} is used for up to 65500 rows and a
 * {@code TwoTablesProximityMatrix} (the table paired with itself) above that. With two tables, a
 * {@code TwoTablesProximityMatrix} of the first vs. the second table is used.
 *
 * @param tables the input table(s); must contain exactly one or two tables
 * @param modelPortObject the tree ensemble whose trees are used to compute proximities
 * @param exec used for progress reporting (and, presumably, cancellation checks inside the helpers)
 * @return the filled proximity matrix, normalized by the number of trees
 * @throws IllegalArgumentException if {@code tables} does not contain exactly one or two tables
 * @throws InvalidSettingsException as declared; presumably raised when computing the learn-column indices
 * @throws InterruptedException if the thread is interrupted while waiting for a semaphore permit
 * @throws ExecutionException as declared; presumably rethrown from a failed per-tree task
 * @throws CanceledExecutionException as declared; presumably raised when the user cancels
 */
public static ProximityMatrix calcProximities(final BufferedDataTable[] tables, final TreeEnsembleModelPortObject modelPortObject, final ExecutionContext exec) throws InvalidSettingsException, InterruptedException, ExecutionException, CanceledExecutionException {
ProximityMatrix proximityMatrix = null;
// true when a second (optional) table is given; selects the two-table matrix/runnable variants below
boolean optionalTable = false;
switch(tables.length) {
case 1:
// NOTE(review): 65500 is presumably a size limit of the single-table matrix storage — TODO confirm
if (tables[0].size() <= 65500) {
proximityMatrix = new SingleTableProximityMatrix(tables[0]);
} else {
// this is unfortunate and we should maybe think of a different solution
// NOTE(review): optionalTable stays false here, so SingleTableProximityCalcRunnable is later handed a
// TwoTablesProximityMatrix — verify that runnable supports this combination.
proximityMatrix = new TwoTablesProximityMatrix(tables[0], tables[0]);
}
break;
case 2:
optionalTable = true;
proximityMatrix = new TwoTablesProximityMatrix(tables[0], tables[1]);
break;
default:
throw new IllegalArgumentException("Currently only up to two tables are supported.");
}
final TreeEnsembleModelPortObjectSpec modelSpec = modelPortObject.getSpec();
final TreeEnsembleModel ensembleModel = modelPortObject.getEnsembleModel();
// One entry per input table: the indices returned by modelSpec.calculateFilterIndices for that table's spec.
int[][] learnColIndicesInTables = null;
if (optionalTable) {
learnColIndicesInTables = new int[][] { modelSpec.calculateFilterIndices(tables[0].getDataTableSpec()), modelSpec.calculateFilterIndices(tables[1].getDataTableSpec()) };
} else {
learnColIndicesInTables = new int[][] { modelSpec.calculateFilterIndices(tables[0].getDataTableSpec()) };
}
final ThreadPool tp = KNIMEConstants.GLOBAL_THREAD_POOL;
// Permits = 1.5 x available processors; each submitted tree takes one permit, which bounds the
// number of outstanding tasks (the runnables receive the semaphore and presumably release it when done).
final int procCount = 3 * Runtime.getRuntime().availableProcessors() / 2;
final Semaphore semaphore = new Semaphore(procCount);
// Holds the first failure reported by a task or by Future.get(); inspected via checkThrowable.
final AtomicReference<Throwable> proxThrowableRef = new AtomicReference<Throwable>();
final int nrTrees = ensembleModel.getNrModels();
final Future<?>[] calcFutures = new Future<?>[nrTrees];
exec.setProgress(0, "Starting proximity calculation per tree.");
for (int i = 0; i < nrTrees; i++) {
// Blocks while all permits are in use.
semaphore.acquire();
finishedTree(i, exec, nrTrees);
// Presumably aborts early if an earlier task already failed — confirm against checkThrowable.
checkThrowable(proxThrowableRef);
AbstractTreeModel treeModel = ensembleModel.getTreeModel(i);
// Zero-weight sub-progress handed to the per-tree task.
ExecutionMonitor subExec = exec.createSubProgress(0.0);
if (optionalTable) {
calcFutures[i] = tp.enqueue(new TwoTablesProximityCalcRunnable(proximityMatrix, tables, learnColIndicesInTables, treeModel, modelPortObject, semaphore, proxThrowableRef, subExec));
} else {
calcFutures[i] = tp.enqueue(new SingleTableProximityCalcRunnable(proximityMatrix, tables, learnColIndicesInTables, treeModel, modelPortObject, semaphore, proxThrowableRef, subExec));
}
}
// Re-acquiring all procCount permits returns only once every outstanding task has given its permit back.
// NOTE(review): when nrTrees < procCount the first tree indices passed to finishedTree are negative —
// verify finishedTree tolerates this.
for (int i = 0; i < procCount; i++) {
semaphore.acquire();
finishedTree(nrTrees - procCount + i, exec, nrTrees);
}
for (Future<?> future : calcFutures) {
try {
future.get();
} catch (Exception e) {
// Keeps only the first failure (compareAndSet from null).
// NOTE(review): this also captures InterruptedException without restoring the thread's interrupt flag.
proxThrowableRef.compareAndSet(null, e);
}
}
checkThrowable(proxThrowableRef);
// Each tree contributes to the matrix; scale by 1 / nrTrees.
proximityMatrix.normalize(1.0 / nrTrees);
return proximityMatrix;
}
Usage of org.knime.base.node.mine.treeensemble2.node.proximity.ProximityMatrix in project knime-core by KNIME.
Class RandomForestNearestNeighborNodeModel, method execute:
/**
 * {@inheritDoc}
 * <p>
 * Computes the proximity matrix for the input table(s) with the measure selected in the dialog, then derives the
 * {@code k} nearest neighbors from it. The matrix is computed exactly once, using only the configured measure.
 */
@Override
protected BufferedDataTable[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final TreeEnsembleModelPortObject ensembleModel = (TreeEnsembleModelPortObject) inObjects[0];
    // The second data port (index 2) is optional.
    final boolean optionalTable = inObjects[2] != null;
    final BufferedDataTable[] tables = new BufferedDataTable[optionalTable ? 2 : 1];
    tables[0] = (BufferedDataTable) inObjects[1];
    if (optionalTable) {
        tables[1] = (BufferedDataTable) inObjects[2];
    }
    // Sub-contexts weighted 0.6 (proximity computation) and 0.4 (nearest-neighbor extraction).
    final ExecutionContext proxExec = exec.createSubExecutionContext(0.6);
    final ExecutionContext nnExec = exec.createSubExecutionContext(0.4);
    exec.setMessage("Calculating");
    final ProximityMeasure proximityMeasure = ProximityMeasure.valueOf(m_proximityMeasure.getStringValue());
    // Only the selected measure is computed. An extra, unconditional calcProximities call here would process
    // every tree a second time (or compute a matrix that is then discarded) and drive proxExec's progress twice.
    final ProximityMatrix proximityMatrix;
    switch (proximityMeasure) {
        case Proximity:
            proximityMatrix = Proximity.calcProximities(tables, ensembleModel, proxExec);
            break;
        case PathProximity:
            proximityMatrix = new PathProximity(tables, ensembleModel).calculatePathProximities(proxExec);
            break;
        default:
            throw new IllegalStateException("Encountered unknown proximity measure.");
    }
    exec.setMessage("Calculating nearest neighbors");
    final int k = m_numNearestNeighbors.getIntValue();
    return proximityMatrix.getNearestNeighbors(nnExec, k);
}
Aggregations