Use of org.knime.base.node.mine.treeensemble2.sample.row.RowSample in project knime-core by KNIME.
The class TreeNumericColumnData, method calcBestSplitRegression:
/**
 * Computes the best binary split of this numeric column for a regression target.
 * <p>
 * The column is scanned in sorted order (via the column memberships) and, at every
 * boundary between two distinct consecutive values, the gain of splitting there is
 * evaluated. The gain criterion is the sum over both children of
 * {@code sum(y)^2 / n}, minus the same quantity for the unsplit node. Missing
 * values either follow the majority child or — with XGBoost missing-value
 * handling — are tried on both sides and sent to whichever side yields the
 * larger gain.
 *
 * @param dataMemberships row membership information of the current branch
 * @param targetPriors precomputed target sum / record count of the current branch
 * @param targetColumn the numeric target column
 * @param rd random source, used only to break ties between equally good splits
 * @return the best split candidate, or {@code null} if no split with positive
 *         gain satisfying the minimum child size exists
 */
@Override
public SplitCandidate calcBestSplitRegression(final DataMemberships dataMemberships, final RegressionPriors targetPriors, final TreeTargetNumericColumnData targetColumn, final RandomData rd) {
final TreeEnsembleLearnerConfiguration config = getConfiguration();
final boolean useAverageSplitPoints = config.isUseAverageSplitPoints();
final int minChildNodeSize = config.getMinChildSize();
// memberships of this branch's rows for this attribute column (sorted order)
final ColumnMemberships columnMemberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
final int lengthNonMissing = getLengthNonMissing();
// missing value handling strategy
final boolean useXGBoostMissingValueHandling = config.getMissingValueHandling() == MissingValueHandling.XGBoost;
// are there missing values in this column (complete column)
boolean branchContainsMissingValues = containsMissingValues();
boolean missingsGoLeft = true;
double missingWeight = 0.0;
double missingY = 0.0;
// check if there are missing values in this row sample; rows with missing
// values sit at the end of the sorted column (index >= lengthNonMissing),
// so walk backwards from the last membership
if (branchContainsMissingValues) {
columnMemberships.goToLast();
while (columnMemberships.getIndexInColumn() >= lengthNonMissing) {
missingWeight += columnMemberships.getRowWeight();
// NOTE(review): the target value is NOT scaled by the row weight here,
// although missingWeight accumulates weights — confirm that weights are
// always 1 on this path (the main loop below rejects non-integer weights)
missingY += targetColumn.getValueFor(columnMemberships.getOriginalIndex());
if (!columnMemberships.previous()) {
break;
}
}
columnMemberships.reset();
branchContainsMissingValues = missingWeight > 0.0;
}
// totals over the non-missing rows of this branch
final double ySumTotal = targetPriors.getYSum() - missingY;
final double nrRecordsTotal = targetPriors.getNrRecords() - missingWeight;
// criterion of the unsplit node; with XGBoost handling the missing rows stay
// part of the node, otherwise they are excluded from the criterion entirely
final double criterionTotal = useXGBoostMissingValueHandling ? (ySumTotal + missingY) * (ySumTotal + missingY) / (nrRecordsTotal + missingWeight) : ySumTotal * ySumTotal / nrRecordsTotal;
double ySumLeft = 0.0;
double nrRecordsLeft = 0.0;
double ySumRight = ySumTotal;
double nrRecordsRight = nrRecordsTotal;
// all values in the current branch are missing
if (nrRecordsRight == 0) {
// it is impossible to determine a split
return null;
}
double bestSplit = Double.NEGATIVE_INFINITY;
double bestImprovement = 0.0;
// statistics of the previous record, flushed lazily into the left child so a
// split is only evaluated between two DISTINCT attribute values
double lastSeenY = Double.NaN;
double lastSeenValue = Double.NEGATIVE_INFINITY;
double lastSeenWeight = -1.0;
// compute the gain at each candidate boundary, keep the split that maximizes it
while (columnMemberships.next()) {
final double weight = columnMemberships.getRowWeight();
if (weight < EPSILON) {
// ignore record: not in current branch or not in sample
continue;
} else if (Math.floor(weight) != weight) {
throw new UnsupportedOperationException("weighted records (missing values?) not supported, " + "weight is " + weight);
}
// attribute value at the current sorted position
final double value = getSorted(columnMemberships.getIndexInColumn());
if (lastSeenWeight > 0.0) {
// move the previous record from the right child to the left child
ySumLeft += lastSeenWeight * lastSeenY;
ySumRight -= lastSeenWeight * lastSeenY;
nrRecordsLeft += lastSeenWeight;
nrRecordsRight -= lastSeenWeight;
// a split is valid only if both children meet the minimum size and the
// attribute value strictly increased (no split inside a run of equal values)
if (nrRecordsLeft >= minChildNodeSize && nrRecordsRight >= minChildNodeSize && lastSeenValue < value) {
boolean tempMissingsGoLeft = true;
double childrenSquaredSum;
if (branchContainsMissingValues && useXGBoostMissingValueHandling) {
// XGBoost-style: try sending the missing rows left [0] and right [1],
// keep the direction with the larger criterion
final double[] tempChildrenSquaredSum = new double[2];
tempChildrenSquaredSum[0] = ((ySumLeft + missingY) * (ySumLeft + missingY) / (nrRecordsLeft + missingWeight)) + (ySumRight * ySumRight / nrRecordsRight);
tempChildrenSquaredSum[1] = (ySumLeft * ySumLeft / nrRecordsLeft) + ((ySumRight + missingY) * (ySumRight + missingY) / (nrRecordsRight + missingWeight));
if (tempChildrenSquaredSum[0] >= tempChildrenSquaredSum[1]) {
childrenSquaredSum = tempChildrenSquaredSum[0];
tempMissingsGoLeft = true;
} else {
childrenSquaredSum = tempChildrenSquaredSum[1];
tempMissingsGoLeft = false;
}
} else {
childrenSquaredSum = (ySumLeft * ySumLeft / nrRecordsLeft) + (ySumRight * ySumRight / nrRecordsRight);
}
double criterion = childrenSquaredSum - criterionTotal;
// on an exact tie with the current best, flip a coin to avoid a
// systematic bias towards smaller split values
boolean randomTieBreaker = criterion == bestImprovement ? rd.nextInt(0, 1) == 1 : false;
if (criterion > bestImprovement || randomTieBreaker) {
bestImprovement = criterion;
bestSplit = useAverageSplitPoints ? getCenter(lastSeenValue, value) : lastSeenValue;
// if there are no missing values go with majority
missingsGoLeft = branchContainsMissingValues ? tempMissingsGoLeft : nrRecordsLeft >= nrRecordsRight;
}
}
}
lastSeenY = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
lastSeenValue = value;
lastSeenWeight = weight;
}
// no final flush of the pending lastSeen* statistics is needed: a split after
// the last value would put every record into the left child, which is not a
// valid binary split
if (bestImprovement > 0.0) {
if (useXGBoostMissingValueHandling) {
// encode the learned default direction for missing values in the candidate
return new NumericSplitCandidate(this, bestSplit, bestImprovement, new BitSet(), missingsGoLeft ? NumericSplitCandidate.MISSINGS_GO_LEFT : NumericSplitCandidate.MISSINGS_GO_RIGHT);
}
return new NumericSplitCandidate(this, bestSplit, bestImprovement, getMissedRows(columnMemberships), NumericSplitCandidate.NO_MISSINGS);
} else {
return null;
}
}
Use of org.knime.base.node.mine.treeensemble2.sample.row.RowSample in project knime-core by KNIME.
The class RegressionTreeLearnerNodeModel, method execute:
/**
 * {@inheritDoc}
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final BufferedDataTable inTable = (BufferedDataTable) inObjects[0];
    final DataTableSpec inSpec = inTable.getDataTableSpec();
    // restrict the input to the configured learning columns
    final FilterLearnColumnRearranger learnRearranger = m_configuration.filterLearnColumns(inSpec);
    String warningMsg = learnRearranger.getWarning();
    final BufferedDataTable learnTable = exec.createColumnRearrangeTable(inTable, learnRearranger, exec.createSubProgress(0.0));
    final DataTableSpec learnSpec = learnTable.getDataTableSpec();
    // split the overall progress: 10% reading, 90% learning
    final ExecutionMonitor readInExec = exec.createSubProgress(0.1);
    final ExecutionMonitor learnExec = exec.createSubProgress(0.9);
    final TreeDataCreator dataCreator = new TreeDataCreator(m_configuration, learnSpec, learnTable.getRowCount());
    exec.setProgress("Reading data into memory");
    final TreeData data = dataCreator.readData(learnTable, m_configuration, readInExec);
    m_hiliteRowSample = dataCreator.getDataRowsForHilite();
    m_viewMessage = dataCreator.getViewMessage();
    // merge the warning from column filtering with the one from data creation
    final String dataCreationWarning = dataCreator.getAndClearWarningMessage();
    if (dataCreationWarning != null) {
        warningMsg = warningMsg == null ? dataCreationWarning : warningMsg + "\n" + dataCreationWarning;
    }
    readInExec.setProgress(1.0);
    exec.setMessage("Learning tree");
    final RandomData rd = m_configuration.createRandomData();
    // bit-vector data needs a specialized index manager
    final IDataIndexManager indexManager = data.getTreeType() == TreeType.BitVector
        ? new BitVectorDataIndexManager(data.getNrRows())
        : new DefaultDataIndexManager(data);
    final int maxLevels = m_configuration.getMaxLevels();
    // presize the signature factory when the tree depth is bounded
    final TreeNodeSignatureFactory signatureFactory = maxLevels < TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE
        ? new TreeNodeSignatureFactory(IntMath.pow(2, maxLevels - 1))
        : new TreeNodeSignatureFactory();
    // draw the row sample and learn the single regression tree
    final RowSample rowSample = m_configuration.createRowSampler(data).createRowSample(rd);
    final TreeLearnerRegression treeLearner = new TreeLearnerRegression(m_configuration, data, indexManager, signatureFactory, rd, rowSample);
    final TreeModelRegression regTree = treeLearner.learnSingleTree(learnExec, rd);
    // wrap the learned tree into the model port object
    final RegressionTreeModel model = new RegressionTreeModel(m_configuration, data.getMetaData(), regTree, data.getTreeType());
    final RegressionTreeModelPortObjectSpec treePortObjectSpec = new RegressionTreeModelPortObjectSpec(learnSpec);
    final RegressionTreeModelPortObject treePortObject = new RegressionTreeModelPortObject(model, treePortObjectSpec);
    learnExec.setProgress(1.0);
    m_treeModelPortObject = treePortObject;
    if (warningMsg != null) {
        setWarningMessage(warningMsg);
    }
    return new PortObject[] { treePortObject };
}
Use of org.knime.base.node.mine.treeensemble2.sample.row.RowSample in project knime-core by KNIME.
The class TreeEnsembleLearner, method learnEnsemble:
/**
 * Learns the configured number of trees concurrently on the global KNIME thread
 * pool and assembles them into a {@link TreeEnsembleModel}. Per-tree row samples
 * and column sample strategies are kept in parallel member arrays.
 *
 * @param exec monitor for progress reporting and cancellation
 * @return the learned ensemble model (also stored in {@code m_ensembleModel})
 * @throws CanceledExecutionException if the user canceled the execution
 * @throws ExecutionException if one of the tree learners failed -- presumably
 *         rethrown by {@code checkThrowable}; confirm its contract
 */
public TreeEnsembleModel learnEnsemble(final ExecutionMonitor exec) throws CanceledExecutionException, ExecutionException {
final int nrModels = m_config.getNrModels();
final RandomData rd = m_config.createRandomData();
final ThreadPool tp = KNIMEConstants.GLOBAL_THREAD_POOL;
// first failure of any tree learner is recorded here and checked via checkThrowable
final AtomicReference<Throwable> learnThrowableRef = new AtomicReference<Throwable>();
@SuppressWarnings("unchecked") final Future<TreeLearnerResult>[] modelFutures = new Future[nrModels];
// throttle: at most 1.5x the available processors learning at the same time
final int procCount = 3 * Runtime.getRuntime().availableProcessors() / 2;
final Semaphore semaphore = new Semaphore(procCount);
Callable<TreeLearnerResult[]> learnCallable = new Callable<TreeLearnerResult[]>() {
@Override
public TreeLearnerResult[] call() throws Exception {
final TreeLearnerResult[] results = new TreeLearnerResult[nrModels];
for (int i = 0; i < nrModels; i++) {
// block until a learner slot is free; the semaphore is handed to each
// TreeLearnerCallable, which presumably releases its permit on
// completion -- TODO confirm
semaphore.acquire();
// a freed permit means the tree submitted procCount iterations earlier
// has finished (negative indices during warm-up are ignored below)
finishedTree(i - procCount, exec);
// fail fast if a previously submitted learner already threw
checkThrowable(learnThrowableRef);
// derive an independent per-tree random seed from the master stream
RandomData rdSingle = TreeEnsembleLearnerConfiguration.createRandomData(rd.nextLong(Long.MIN_VALUE, Long.MAX_VALUE));
ExecutionMonitor subExec = exec.createSubProgress(0.0);
modelFutures[i] = tp.enqueue(new TreeLearnerCallable(subExec, rdSingle, learnThrowableRef, semaphore));
}
// drain all permits to wait for the learners that are still in flight
for (int i = 0; i < procCount; i++) {
semaphore.acquire();
// NOTE(review): the highest index reported here is nrModels - 2; confirm
// whether the last tree's completion is intentionally left unreported
finishedTree(nrModels - 1 + i - procCount, exec);
}
// collect results; only the FIRST failure is kept (compareAndSet) for rethrow
for (int i = 0; i < nrModels; i++) {
try {
results[i] = modelFutures[i].get();
} catch (Exception e) {
learnThrowableRef.compareAndSet(null, e);
}
}
return results;
}
// reports fractional progress for a finished tree; indices <= 0 are skipped
private void finishedTree(final int treeIndex, final ExecutionMonitor progMon) {
if (treeIndex > 0) {
progMon.setProgress(treeIndex / (double) nrModels, "Tree " + treeIndex + "/" + nrModels);
}
}
};
// run the coordinator outside normal pool accounting -- presumably so waiting
// does not block a pool slot; see ThreadPool#runInvisible
TreeLearnerResult[] modelResults = tp.runInvisible(learnCallable);
checkThrowable(learnThrowableRef);
AbstractTreeModel[] models = new AbstractTreeModel[nrModels];
m_rowSamples = new RowSample[nrModels];
m_columnSampleStrategies = new ColumnSampleStrategy[nrModels];
// unpack the per-tree results into parallel arrays kept as members
for (int i = 0; i < nrModels; i++) {
models[i] = modelResults[i].m_treeModel;
m_rowSamples[i] = modelResults[i].m_rowSample;
m_columnSampleStrategies[i] = modelResults[i].m_rootColumnSampleStrategy;
}
m_ensembleModel = new TreeEnsembleModel(m_config, m_data.getMetaData(), models, m_data.getTreeType());
return m_ensembleModel;
}
Aggregations