Use of org.knime.core.node.ExecutionMonitor in project knime-core by KNIME.
Class TreeEnsembleLearner, method learnEnsemble.
/**
 * Learns all trees of the ensemble in parallel and assembles them into a {@link TreeEnsembleModel}.
 * <p>
 * The number of trees is taken from {@code m_config.getNrModels()}. Each tree is learned by a
 * {@code TreeLearnerCallable} enqueued on {@code KNIMEConstants.GLOBAL_THREAD_POOL}, with its own
 * {@code RandomData} seeded from the configuration's master random generator. A semaphore with
 * {@code 3/2 * availableProcessors} permits limits how many tree tasks are admitted at a time.
 * <p>
 * Side effects: fills {@code m_rowSamples} and {@code m_columnSampleStrategies} (one entry per tree) and
 * stores the resulting model in {@code m_ensembleModel}, which is also returned.
 *
 * @param exec monitor used for progress reporting; a zero-weight sub-progress is handed to every tree task
 * @return the learned ensemble model (same instance as {@code m_ensembleModel})
 * @throws CanceledExecutionException as declared; presumably raised via {@code checkThrowable} when a task
 *             recorded a cancellation (TODO confirm)
 * @throws ExecutionException as declared; presumably raised via {@code checkThrowable} / {@code runInvisible}
 *             when tree learning failed (TODO confirm)
 */
public TreeEnsembleModel learnEnsemble(final ExecutionMonitor exec) throws CanceledExecutionException, ExecutionException {
final int nrModels = m_config.getNrModels();
final RandomData rd = m_config.createRandomData();
final ThreadPool tp = KNIMEConstants.GLOBAL_THREAD_POOL;
// First failure from any tree task or from collecting the futures (compareAndSet(null, ...) keeps only the first one).
final AtomicReference<Throwable> learnThrowableRef = new AtomicReference<Throwable>();
@SuppressWarnings("unchecked") final Future<TreeLearnerResult>[] modelFutures = new Future[nrModels];
// Upper bound on tree tasks admitted concurrently (1.5 x available processors).
final int procCount = 3 * Runtime.getRuntime().availableProcessors() / 2;
// NOTE(review): presumably each TreeLearnerCallable releases one permit when it finishes - confirm it does so in a finally block,
// otherwise a failing task would make the acquire() calls below block forever.
final Semaphore semaphore = new Semaphore(procCount);
Callable<TreeLearnerResult[]> learnCallable = new Callable<TreeLearnerResult[]>() {
@Override
public TreeLearnerResult[] call() throws Exception {
final TreeLearnerResult[] results = new TreeLearnerResult[nrModels];
for (int i = 0; i < nrModels; i++) {
// Blocks once procCount tasks are outstanding; acquiring a permit implies an earlier task has finished,
// so progress is reported as (approximately) i - procCount trees done.
semaphore.acquire();
finishedTree(i - procCount, exec);
// Stop scheduling new trees once a failure has been recorded (assumes checkThrowable rethrows it - TODO confirm).
checkThrowable(learnThrowableRef);
// Each tree gets an independent random generator derived from the master generator.
RandomData rdSingle = TreeEnsembleLearnerConfiguration.createRandomData(rd.nextLong(Long.MIN_VALUE, Long.MAX_VALUE));
ExecutionMonitor subExec = exec.createSubProgress(0.0);
modelFutures[i] = tp.enqueue(new TreeLearnerCallable(subExec, rdSingle, learnThrowableRef, semaphore));
}
// Drain: acquiring procCount more permits presumably waits until all outstanding tasks have released theirs.
// NOTE(review): the indices reported here go up to nrModels - 2 only, so the final "Tree n/n" is never shown.
for (int i = 0; i < procCount; i++) {
semaphore.acquire();
finishedTree(nrModels - 1 + i - procCount, exec);
}
// Collect results; a failed future leaves a null entry and records its exception (first failure wins).
for (int i = 0; i < nrModels; i++) {
try {
results[i] = modelFutures[i].get();
} catch (Exception e) {
learnThrowableRef.compareAndSet(null, e);
}
}
return results;
}
/** Reports progress for the given (estimated) number of finished trees; indices <= 0 are ignored. */
private void finishedTree(final int treeIndex, final ExecutionMonitor progMon) {
if (treeIndex > 0) {
progMon.setProgress(treeIndex / (double) nrModels, "Tree " + treeIndex + "/" + nrModels);
}
}
};
// Presumably runs the scheduler on the calling thread without it counting against the pool's thread limit - confirm ThreadPool#runInvisible.
TreeLearnerResult[] modelResults = tp.runInvisible(learnCallable);
// Surface any failure recorded by tree tasks or while collecting futures before using the (possibly null) results.
checkThrowable(learnThrowableRef);
AbstractTreeModel[] models = new AbstractTreeModel[nrModels];
m_rowSamples = new RowSample[nrModels];
m_columnSampleStrategies = new ColumnSampleStrategy[nrModels];
for (int i = 0; i < nrModels; i++) {
models[i] = modelResults[i].m_treeModel;
m_rowSamples[i] = modelResults[i].m_rowSample;
m_columnSampleStrategies[i] = modelResults[i].m_rootColumnSampleStrategy;
}
m_ensembleModel = new TreeEnsembleModel(m_config, m_data.getMetaData(), models, m_data.getTreeType());
return m_ensembleModel;
}
Use of org.knime.core.node.ExecutionMonitor in project knime-core by KNIME.
Class TreeEnsembleRegressionLearnerNodeModel, method saveInternals.
/**
 * {@inheritDoc}
 *
 * <p>Writes up to three optional files into {@code nodeInternDir}, each only if the corresponding state is present:
 * the gzip-compressed ensemble model of an old (pre-2.10) workflow, the row sample used for hiliting, and the view
 * warning message (as {@code NodeSettings} XML). Every stream opened here is closed, even if writing fails.
 */
@Override
protected void saveInternals(final File nodeInternDir, final ExecutionMonitor exec) throws IOException, CanceledExecutionException {
    if (m_oldStyleEnsembleModel_deprecated != null) {
        // old workflow (<2.10) loaded and saved ...
        final File file = new File(nodeInternDir, INTERNAL_TREES_FILE);
        // The file stream is its own resource so that it is closed even if the GZIP header cannot be written.
        // Closing twice (should save() close the stream itself) is harmless for both stream types.
        try (FileOutputStream fileOut = new FileOutputStream(file);
                OutputStream out = new GZIPOutputStream(fileOut)) {
            final ExecutionMonitor sub = exec.createSubProgress(0.2);
            m_oldStyleEnsembleModel_deprecated.save(out, sub);
        }
    }
    if (m_hiliteRowSample != null) {
        final File file = new File(nodeInternDir, INTERNAL_DATASAMPLE_FILE);
        final ExecutionMonitor sub = exec.createSubProgress(0.2);
        DataContainer.writeToZip(m_hiliteRowSample, file, sub);
    }
    if (m_viewMessage != null) {
        final File file = new File(nodeInternDir, INTERNAL_INFO_FILE);
        final NodeSettings sets = new NodeSettings("ensembleData");
        sets.addString("view_warning", m_viewMessage);
        // Previously the FileOutputStream was handed over without ever being closed in this method.
        try (OutputStream xmlOut = new FileOutputStream(file)) {
            sets.saveToXML(xmlOut);
        }
    }
}
Use of org.knime.core.node.ExecutionMonitor in project knime-core by KNIME.
Class TreeEnsembleRegressionLearnerNodeModel, method execute.
/**
 * {@inheritDoc}
 *
 * <p>Learns a regression tree ensemble from the first input table. Progress is split into reading the data into
 * memory (10%), learning the trees (80%) and the out-of-bag prediction (10%). Warnings from column filtering and
 * from data creation are combined (newline-separated) into the node warning.
 *
 * @return the out-of-bag prediction table, the column statistics table and the model port object, in that order
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final BufferedDataTable inTable = (BufferedDataTable) inObjects[0];
    final DataTableSpec inSpec = inTable.getDataTableSpec();

    // restrict the input to the learning columns
    final FilterLearnColumnRearranger columnFilter = m_configuration.filterLearnColumns(inSpec);
    String warning = columnFilter.getWarning();
    final BufferedDataTable filteredTable =
        exec.createColumnRearrangeTable(inTable, columnFilter, exec.createSubProgress(0.0));
    final DataTableSpec filteredSpec = filteredTable.getDataTableSpec();
    final TreeEnsembleModelPortObjectSpec modelSpec = m_configuration.createPortObjectSpec(filteredSpec);

    final ExecutionMonitor readProgress = exec.createSubProgress(0.1);
    final ExecutionMonitor learnProgress = exec.createSubProgress(0.8);
    final ExecutionMonitor oobProgress = exec.createSubProgress(0.1);

    // phase 1: read the data into memory
    final TreeDataCreator creator = new TreeDataCreator(m_configuration, filteredSpec, filteredTable.getRowCount());
    exec.setProgress("Reading data into memory");
    final TreeData treeData = creator.readData(filteredTable, m_configuration, readProgress);
    m_hiliteRowSample = creator.getDataRowsForHilite();
    m_viewMessage = creator.getViewMessage();
    final String creatorWarning = creator.getAndClearWarningMessage();
    if (creatorWarning != null) {
        warning = (warning == null) ? creatorWarning : warning + "\n" + creatorWarning;
    }
    readProgress.setProgress(1.0);

    // phase 2: learn the trees, unwrapping failures of the worker threads
    exec.setMessage("Learning trees");
    final TreeEnsembleLearner ensembleLearner = new TreeEnsembleLearner(m_configuration, treeData);
    final TreeEnsembleModel ensemble;
    try {
        ensemble = ensembleLearner.learnEnsemble(learnProgress);
    } catch (ExecutionException ee) {
        final Throwable root = ee.getCause();
        if (!(root instanceof Exception)) {
            throw ee;
        }
        throw (Exception) root;
    }
    final TreeEnsembleModelPortObject portObject = new TreeEnsembleModelPortObject(modelSpec, ensemble);
    learnProgress.setProgress(1.0);

    // phase 3: predict every row with the trees that did not see it during learning
    exec.setMessage("Out of bag prediction");
    final TreeEnsemblePredictor oobPredictor = createOutOfBagPredictor(modelSpec, portObject, inSpec);
    oobPredictor.setOutofBagFilter(ensembleLearner.getRowSamples(), treeData.getTargetColumn());
    final BufferedDataTable oobTable =
        exec.createColumnRearrangeTable(inTable, oobPredictor.getPredictionRearranger(), oobProgress);
    final BufferedDataTable statisticsTable =
        ensembleLearner.createColumnStatisticTable(exec.createSubExecutionContext(0.0));

    m_ensembleModelPortObject = portObject;
    if (warning != null) {
        setWarningMessage(warning);
    }
    return new PortObject[] { oobTable, statisticsTable, portObject };
}
Use of org.knime.core.node.ExecutionMonitor in project knime-core by KNIME.
Class IrlsLearner, method learn.
/**
 * {@inheritDoc}
 *
 * <p>Fits the coefficients with iteratively reweighted least squares (IRLS). Each iteration runs the update step
 * ({@code irlsRls}) and the log-likelihood computation on a pool thread so that the work can be interrupted. If the
 * log-likelihood decreased compared to the previous iteration, the step is repeatedly halved towards the previous
 * coefficients until it no longer decreases or the coefficients have converged. The loop stops once every
 * coefficient changed by at most {@code m_eps} relative to its previous value, or after {@code m_maxIter}
 * iterations.
 *
 * <p>A warning is stored in {@code m_warning} if the iteration did not converge or the covariance matrix could not
 * be computed; otherwise {@code m_warning} is set to {@code null}.
 *
 * @throws RuntimeException if the log-likelihood becomes infinite or NaN, if the worker task fails, or if the
 *             waiting thread is interrupted without the node being canceled
 */
@Override
public LogRegLearnerResult learn(final TrainingData<ClassificationTrainingRow> trainingData, final ExecutionMonitor exec) throws CanceledExecutionException, InvalidSettingsException {
    exec.checkCanceled();
    int iter = 0;
    boolean converged = false;
    final int tcC = trainingData.getTargetDimension() + 1;
    final int rC = trainingData.getFeatureCount() - 1;
    // all coefficients are kept in a single row: (tcC - 1) consecutive blocks of (rC + 1) entries
    final RealMatrix beta = MatrixUtils.createRealMatrix(1, (tcC - 1) * (rC + 1));
    double loglike = 0.0;
    exec.setMessage("Iterative optimization. Processing iteration 1.");
    // main loop
    while (iter < m_maxIter && !converged) {
        final RealMatrix betaOld = beta.copy();
        final double loglikeOld = loglike;
        // Do heavy work in a separate thread which allows to interrupt it
        // note the queue may block if no more threads are available (e.g. thread count = 1)
        // as soon as we stall in 'get' this thread reduces the number of running thread
        final Future<Double> future = ThreadPool.currentPool().enqueue(new Callable<Double>() {
            @Override
            public Double call() throws Exception {
                final ExecutionMonitor progMon = exec.createSubProgress(1.0 / m_maxIter);
                irlsRls(trainingData, beta, rC, tcC, progMon);
                progMon.setProgress(1.0);
                return likelihood(trainingData.iterator(), beta, rC, tcC, exec);
            }
        });
        try {
            loglike = future.get();
        } catch (InterruptedException e) {
            future.cancel(true);
            // if the node was canceled, this throws CanceledExecutionException
            exec.checkCanceled();
            // any other interruption: keep the thread's interrupt status instead of swallowing it
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            if (e.getCause() instanceof RuntimeException) {
                throw (RuntimeException) e.getCause();
            } else {
                throw new RuntimeException(e.getCause());
            }
        }
        if (Double.isInfinite(loglike) || Double.isNaN(loglike)) {
            throw new RuntimeException(FAILING_MSG);
        }
        exec.checkCanceled();
        // test for decreasing likelihood: step back towards betaOld until it no longer decreases
        while ((Double.isInfinite(loglike) || Double.isNaN(loglike) || loglike < loglikeOld) && iter > 0) {
            converged = isBetaConverged(beta, betaOld);
            if (converged) {
                break;
            }
            // half the step size of beta
            beta.setSubMatrix((beta.add(betaOld)).scalarMultiply(0.5).getData(), 0, 0);
            exec.checkCanceled();
            loglike = likelihood(trainingData.iterator(), beta, rC, tcC, exec);
            exec.checkCanceled();
        }
        // test for convergence
        converged = isBetaConverged(beta, betaOld);
        iter++;
        LOGGER.debug("#Iterations: " + iter);
        LOGGER.debug("Log Likelihood: " + loglike);
        StringBuilder betaBuilder = new StringBuilder();
        for (int i = 0; i < beta.getColumnDimension() - 1; i++) {
            betaBuilder.append(Double.toString(beta.getEntry(0, i)));
            betaBuilder.append(", ");
        }
        if (beta.getColumnDimension() > 0) {
            betaBuilder.append(Double.toString(beta.getEntry(0, beta.getColumnDimension() - 1)));
        }
        LOGGER.debug("beta: " + betaBuilder.toString());
        exec.checkCanceled();
        exec.setMessage("Iterative optimization. #Iterations: " + iter + " | Log-likelihood: " + DoubleFormat.formatDouble(loglike) + ". Processing iteration " + (iter + 1) + ".");
    }
    StringBuilder warnBuilder = new StringBuilder();
    // The loop only ends without convergence when the iteration limit is hit. Checking 'iter >= m_maxIter'
    // instead would also warn when convergence was reached in the very last allowed iteration.
    if (!converged) {
        warnBuilder.append("The algorithm did not reach convergence after the specified number of epochs. " + "Setting the epoch limit higher might result in a better model.");
    }
    // The covariance matrix
    RealMatrix covMat = null;
    if (m_calcCovMatrix) {
        try {
            covMat = new QRDecomposition(A).getSolver().getInverse().scalarMultiply(-1);
        } catch (SingularMatrixException sme) {
            if (warnBuilder.length() > 0) {
                warnBuilder.append("\n");
            }
            warnBuilder.append("The covariance matrix could not be calculated because the" + " observed fisher information matrix was singular.");
        }
    }
    // unpack the single coefficient row into a (tcC - 1) x (rC + 1) matrix
    RealMatrix betaMat = MatrixUtils.createRealMatrix(tcC - 1, rC + 1);
    for (int i = 0; i < beta.getColumnDimension(); i++) {
        int r = i / (rC + 1);
        int c = i % (rC + 1);
        betaMat.setEntry(r, c, beta.getEntry(0, i));
    }
    m_warning = warnBuilder.length() > 0 ? warnBuilder.toString() : null;
    return new LogRegLearnerResult(betaMat, covMat, iter, loglike);
}

/**
 * Checks whether every coefficient in {@code beta} differs from the corresponding entry of {@code betaOld} by at
 * most {@code m_eps} times the absolute previous value.
 *
 * @param beta the current coefficients (single row)
 * @param betaOld the coefficients of the previous iteration (single row, same width)
 * @return {@code true} if no coefficient changed by more than the relative tolerance
 */
private boolean isBetaConverged(final RealMatrix beta, final RealMatrix betaOld) {
    for (int k = 0; k < beta.getColumnDimension(); k++) {
        if (abs(beta.getEntry(0, k) - betaOld.getEntry(0, k)) > m_eps * abs(betaOld.getEntry(0, k))) {
            return false;
        }
    }
    return true;
}
Use of org.knime.core.node.ExecutionMonitor in project knime-core by KNIME.
Class EnrichmentPlotterModel, method execute.
/**
 * {@inheritDoc}
 *
 * <p>Computes one enrichment curve per configured curve definition. The rows of the first input table are ordered
 * by the curve's sort column; rows whose sort value is missing are skipped. The curve value is then accumulated
 * along that order according to the plot mode:
 * <ul>
 * <li>the sum of the activity values,</li>
 * <li>the number of activity values at or above the hit threshold, or</li>
 * <li>the number of activity clusters that reached the minimum member count.</li>
 * </ul>
 * Missing activity values are ignored in every mode. A down-sampled copy of each curve is stored in
 * {@code m_curves}.
 *
 * <p>Two tables are returned:
 * <ul>
 * <li>the area under each curve, normalized by the final curve value;</li>
 * <li>per configured fraction size, the hit rate (normalized the same way) and the enrichment factor.</li>
 * </ul>
 * If a curve's final value is 0, these normalizations divide by zero and yield NaN or infinite values.
 */
@Override
protected BufferedDataTable[] execute(final BufferedDataTable[] inData, final ExecutionContext exec) throws Exception {
    final double rowCount = inData[0].size();
    final BufferedDataContainer areaOutCont = exec.createDataContainer(AREA_OUT_SPEC);
    final BufferedDataContainer discrateOutCont = exec.createDataContainer(getDiscrateOutSpec());
    final double[] fractionSizes = m_settings.getFractionSizes();
    for (int i = 0; i < m_settings.getCurveCount(); i++) {
        final ExecutionMonitor sexec = exec.createSubProgress(1.0 / m_settings.getCurveCount());
        exec.setMessage("Generating curve " + (i + 1));
        final Curve c = m_settings.getCurve(i);
        final Helper[] curve = new Helper[KnowsRowCountTable.checkRowCount(inData[0].size())];
        final int sortIndex = inData[0].getDataTableSpec().findColumnIndex(c.getSortColumn());
        final int actIndex = inData[0].getDataTableSpec().findColumnIndex(c.getActivityColumn());
        int k = 0, maxK = 0;
        for (DataRow row : inData[0]) {
            DataCell c1 = row.getCell(sortIndex);
            DataCell c2 = row.getCell(actIndex);
            if (k++ % 100 == 0) {
                sexec.checkCanceled();
                sexec.setProgress(k / rowCount);
            }
            // rows without a sort value cannot be placed on the curve
            if (!c1.isMissing()) {
                curve[maxK] = new Helper(((DoubleValue) c1).getDoubleValue(), c2);
                maxK++;
            }
        }
        Arrays.sort(curve, 0, maxK);
        if (c.isSortDescending()) {
            for (int j = 0; j < maxK / 2; j++) {
                Helper h = curve[j];
                curve[j] = curve[maxK - j - 1];
                curve[maxK - j - 1] = h;
            }
        }
        // this is for down-sampling so that the view is faster;
        // plotting >100,000 points takes quite a long time
        final int size = Math.min(MAX_RESOLUTION, maxK);
        final double downSampleRate = maxK / (double) size;
        final double[] xValues = new double[size + 1];
        final double[] yValues = new double[size + 1];
        xValues[0] = 0;
        yValues[0] = 0;
        int lastK = 0;
        double y = 0, area = 0;
        int nextHitRatePoint = 0;
        final double[] hitRateValues = new double[fractionSizes.length];
        final HashMap<DataCell, MutableInteger> clusters = new HashMap<DataCell, MutableInteger>();
        // set hit rate values for fractions that are smaller than 1 row to 0; the bounds check prevents an
        // ArrayIndexOutOfBoundsException when no fractions are configured or every fraction is below one row
        // (e.g. all sort values are missing, so maxK == 0)
        while (nextHitRatePoint < fractionSizes.length && (maxK * fractionSizes[nextHitRatePoint] / 100) < 1) {
            hitRateValues[nextHitRatePoint++] = 0;
        }
        for (k = 1; k <= maxK; k++) {
            final Helper h = curve[k - 1];
            if (m_settings.plotMode() == PlotMode.PlotSum) {
                // a missing cell is not a DoubleValue; skip it like the other plot modes do
                if (!h.b.isMissing()) {
                    y += ((DoubleValue) h.b).getDoubleValue();
                }
            } else if (m_settings.plotMode() == PlotMode.PlotHits) {
                if (!h.b.isMissing() && (((DoubleValue) h.b).getDoubleValue() >= m_settings.hitThreshold())) {
                    y++;
                }
            } else if (!h.b.isMissing()) {
                MutableInteger count = clusters.get(h.b);
                if (count == null) {
                    count = new MutableInteger(0);
                    clusters.put(h.b, count);
                }
                // a cluster counts once, at the moment it reaches the minimum member count
                if (count.inc() == m_settings.minClusterMembers()) {
                    y++;
                }
            }
            area += y / maxK;
            if ((int) (k / downSampleRate) >= lastK + 1) {
                lastK++;
                xValues[lastK] = k;
                yValues[lastK] = y;
            }
            // several fraction sizes may map to the same row index, that's why this needs to be a while loop
            while ((nextHitRatePoint < fractionSizes.length) && (k == (int) Math.floor(maxK * fractionSizes[nextHitRatePoint] / 100))) {
                hitRateValues[nextHitRatePoint] = y;
                nextHitRatePoint++;
            }
        }
        // the last point is always the full curve, independent of down-sampling
        xValues[xValues.length - 1] = maxK;
        yValues[yValues.length - 1] = y;
        area /= y;
        m_curves.add(new EnrichmentPlot(c.getSortColumn() + " vs " + c.getActivityColumn(), xValues, yValues, area));
        areaOutCont.addRowToTable(new DefaultRow(new RowKey(c.toString()), new DoubleCell(area)));
        for (int j = 0; j < hitRateValues.length; j++) {
            hitRateValues[j] /= y;
        }
        double[] enrichmentFactors = new double[hitRateValues.length];
        for (int j = 0; j < enrichmentFactors.length; j++) {
            enrichmentFactors[j] = calculateEnrichmentFactor(hitRateValues[j], fractionSizes[j]);
        }
        discrateOutCont.addRowToTable(new DefaultRow(new RowKey(c.toString()), ArrayUtils.addAll(hitRateValues, enrichmentFactors)));
    }
    areaOutCont.close();
    discrateOutCont.close();
    return new BufferedDataTable[] { areaOutCont.getTable(), discrateOutCont.getTable() };
}
Aggregations