Use of org.knime.base.node.mine.treeensemble2.learner.BitSplitCandidate in project knime-core by knime.
Method calcBestSplitRegression of the class TreeBitVectorColumnData.
/**
 * {@inheritDoc}
 * <p>
 * Evaluates the one split this bit column offers: rows whose bit is set versus rows whose bit is
 * clear. For each side the row weights and the weighted target values are summed; the gain is
 * {@code sumSet^2 / weightSet + sumClear^2 / weightClear - sumTotal^2 / nrRecordsTotal}.
 *
 * @param dataMemberships supplies the column memberships that are iterated for this attribute
 * @param targetPriors supplies the total target sum and total record count of the current node
 * @param targetColumn supplies the target value for each row (looked up by original index)
 * @param rd not used by this implementation
 * @return a {@link BitSplitCandidate} carrying the gain, or {@code null} if the accumulated weight
 *         of either side is below the configured minimum child size or the gain is not positive
 */
@Override
public SplitCandidate calcBestSplitRegression(final DataMemberships dataMemberships, final RegressionPriors targetPriors, final TreeTargetNumericColumnData targetColumn, final RandomData rd) {
    final double totalTargetSum = targetPriors.getYSum();
    final double totalRecordCount = targetPriors.getNrRecords();
    final int smallestAllowedChild = getConfiguration().getMinChildSize();
    final ColumnMemberships memberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());

    // running totals for the "bit set" side and the "bit clear" side
    double setWeightSum = 0.0;
    double setTargetSum = 0.0;
    double clearWeightSum = 0.0;
    double clearTargetSum = 0.0;

    while (memberships.next()) {
        final double rowWeight = memberships.getRowWeight();
        if (rowWeight < EPSILON) {
            // record carries (almost) no weight: not in current branch or not in sample
            continue;
        }
        final double targetValue = targetColumn.getValueFor(memberships.getOriginalIndex());
        final double weightedTarget = rowWeight * targetValue;
        if (m_columnBitSet.get(memberships.getIndexInColumn())) {
            setWeightSum += rowWeight;
            setTargetSum += weightedTarget;
        } else {
            clearWeightSum += rowWeight;
            clearTargetSum += weightedTarget;
        }
    }

    // both children must reach the configured minimum size, otherwise there is no valid split
    final boolean setSideTooSmall = setWeightSum < smallestAllowedChild;
    final boolean clearSideTooSmall = clearWeightSum < smallestAllowedChild;
    if (setSideTooSmall || clearSideTooSmall) {
        return null;
    }

    final double criterionBeforeSplit = totalTargetSum * totalTargetSum / totalRecordCount;
    final double criterionSetSide = setTargetSum * setTargetSum / setWeightSum;
    final double criterionClearSide = clearTargetSum * clearTargetSum / clearWeightSum;
    final double splitGain = criterionSetSide + criterionClearSide - criterionBeforeSplit;

    // only a strictly positive gain yields a candidate
    return splitGain > 0 ? new BitSplitCandidate(this, splitGain) : null;
}
Use of org.knime.base.node.mine.treeensemble2.learner.BitSplitCandidate in project knime-core by knime.
Method calcBestSplitClassification of the class TreeBitVectorColumnData.
/**
 * {@inheritDoc}
 * <p>
 * Evaluates the one split this bit column offers: rows whose bit is set versus rows whose bit is
 * clear. For each side the per-class weight distribution and the total weight are accumulated;
 * the impurity criterion obtained from the priors then derives the partition impurities, the
 * post-split impurity and finally the gain.
 *
 * @param dataMemberships supplies the column memberships that are iterated for this attribute
 * @param targetPriors supplies the impurity criterion and the prior impurity
 * @param targetColumn supplies the target class index for each row (looked up by original index)
 * @param rd not used by this implementation
 * @return a {@link BitSplitCandidate} carrying the computed gain, or {@code null} if the
 *         accumulated weight of either side is below the configured minimum child size; the gain
 *         value itself is not checked here
 */
@Override
public SplitCandidate calcBestSplitClassification(final DataMemberships dataMemberships, final ClassificationPriors targetPriors, final TreeTargetNominalColumnData targetColumn, final RandomData rd) {
    final NominalValueRepresentation[] classValues = targetColumn.getMetaData().getValues();
    final IImpurity impurity = targetPriors.getImpurityCriterion();
    final int smallestAllowedChild = getConfiguration().getMinChildSize();

    // per-class weight distribution for rows with the bit set ('1') and clear ('0')
    final int classCount = classValues.length;
    final double[] setClassWeights = new double[classCount];
    final double[] clearClassWeights = new double[classCount];
    double setWeightSum = 0.0;
    double clearWeightSum = 0.0;

    final ColumnMemberships memberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
    while (memberships.next()) {
        final double rowWeight = memberships.getRowWeight();
        if (rowWeight < EPSILON) {
            // record not in current branch or not in sample; flagged when assertions are enabled,
            // skipped otherwise
            assert false : "This code should never be reached!";
            continue;
        }
        final int classIndex = targetColumn.getValueFor(memberships.getOriginalIndex());
        if (m_columnBitSet.get(memberships.getIndexInColumn())) {
            setWeightSum += rowWeight;
            setClassWeights[classIndex] += rowWeight;
        } else {
            clearWeightSum += rowWeight;
            clearClassWeights[classIndex] += rowWeight;
        }
    }

    // both children must reach the configured minimum size, otherwise there is no valid split
    final boolean setSideTooSmall = setWeightSum < smallestAllowedChild;
    final boolean clearSideTooSmall = clearWeightSum < smallestAllowedChild;
    if (setSideTooSmall || clearSideTooSmall) {
        return null;
    }

    final double totalWeight = setWeightSum + clearWeightSum;
    final double setImpurity = impurity.getPartitionImpurity(setClassWeights, setWeightSum);
    final double clearImpurity = impurity.getPartitionImpurity(clearClassWeights, clearWeightSum);
    final double[] childImpurities = new double[] { setImpurity, clearImpurity };
    final double[] childWeights = new double[] { setWeightSum, clearWeightSum };
    final double impurityAfterSplit = impurity.getPostSplitImpurity(childImpurities, childWeights, totalWeight);
    final double impurityBeforeSplit = targetPriors.getPriorImpurity();
    final double splitGain = impurity.getGain(impurityBeforeSplit, impurityAfterSplit, childWeights, totalWeight);
    return new BitSplitCandidate(this, splitGain);
}
Aggregations