Usage of org.knime.base.node.mine.treeensemble2.sample.row.RowSample in the knime-core project (by KNIME).
Class TreeLearnerRegression, method learnSingleTree:
/**
 * {@inheritDoc}
 * <p>
 * Builds one regression tree starting from the root data memberships of the configured row sample.
 * When the configuration is a {@link GradientBoostingLearnerConfiguration}, a fresh {@code m_leafs}
 * list is assigned before the tree is grown and is passed to the resulting model.
 *
 * @param exec monitor forwarded to the recursive node construction
 * @param rd random data source (not read directly by this method)
 * @return the learned regression tree model
 * @throws CanceledExecutionException as propagated from the node construction
 */
@Override
public TreeModelRegression learnSingleTree(final ExecutionMonitor exec, final RandomData rd) throws CanceledExecutionException {
    final TreeTargetNumericColumnData target = getTargetData();
    final TreeData treeData = getData();
    final RowSample sample = getRowSampling();
    final TreeEnsembleLearnerConfiguration learnerConfig = getConfig();
    final IDataIndexManager dataIndexManager = getIndexManager();

    final DataMemberships memberships = new RootDataMemberships(sample, treeData, dataIndexManager);
    final RegressionPriors priors = target.getPriors(memberships, learnerConfig);
    final BitSet forbiddenColumns = new BitSet(treeData.getNrAttributes());

    final boolean boosting = learnerConfig instanceof GradientBoostingLearnerConfiguration;
    if (boosting) {
        // a new leaf list for this tree; it is handed to the model constructed below
        m_leafs = new ArrayList<TreeNodeRegression>();
    }

    final ColumnSample columnSample =
        getColSamplingStrategy().getColumnSampleForTreeNode(TreeNodeSignature.ROOT_SIGNATURE);
    final TreeNodeRegression root = buildTreeNode(exec, 0, memberships, columnSample,
        getSignatureFactory().getRootSignature(), priors, forbiddenColumns);
    // the recursion is expected to leave no column marked as forbidden
    assert forbiddenColumns.cardinality() == 0;
    root.setTreeNodeCondition(TreeNodeTrueCondition.INSTANCE);

    return boosting ? new TreeModelRegression(root, m_leafs) : new TreeModelRegression(root);
}
Usage of org.knime.base.node.mine.treeensemble2.sample.row.RowSample in the knime-core project (by KNIME).
Class TreeLearnerClassification, method learnSingleTreeRecursive:
/**
 * Learns a single classification tree by building it recursively from the root node.
 *
 * @param exec monitor forwarded to the recursive node construction
 * @param rd random data source; not used by this method
 * @return the learned classification tree model
 * @throws CanceledExecutionException as propagated from the node construction
 */
private TreeModelClassification learnSingleTreeRecursive(final ExecutionMonitor exec, final RandomData rd) throws CanceledExecutionException {
    final TreeData data = getData();
    final RowSample rowSampling = getRowSampling();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    final TreeTargetNominalColumnData targetColumn = (TreeTargetNominalColumnData) data.getTargetColumn();
    final DataMemberships rootDataMemberships = new RootDataMemberships(rowSampling, data, getIndexManager());
    final ClassificationPriors targetPriors = targetColumn.getDistribution(rootDataMemberships, config);
    // columns marked here are excluded from splitting further down the current branch
    final BitSet forbiddenColumnSet = new BitSet(data.getNrAttributes());
    final TreeNodeSignature rootSignature = TreeNodeSignature.ROOT_SIGNATURE;
    final ColumnSample rootColumnSample = getColSamplingStrategy().getColumnSampleForTreeNode(rootSignature);
    final TreeNodeClassification rootNode =
        buildTreeNode(exec, 0, rootDataMemberships, rootColumnSample, rootSignature, targetPriors, forbiddenColumnSet);
    // the recursion is expected to leave no column marked as forbidden
    assert forbiddenColumnSet.cardinality() == 0;
    rootNode.setTreeNodeCondition(TreeNodeTrueCondition.INSTANCE);
    return new TreeModelClassification(rootNode);
}
Usage of org.knime.base.node.mine.treeensemble2.sample.row.RowSample in the knime-core project (by KNIME).
Class MGradientBoostedTreesLearner, method learn:
/**
 * {@inheritDoc}
 * <p>
 * Fits {@code config.getNrModels()} regression trees sequentially, starting from the target
 * median as the prediction for every row. Each tree is trained on clipped residuals: a residual
 * whose magnitude is at most {@code quantile} (from {@code calculateAlphaQuantile(residuals, alpha)})
 * is used as-is, larger ones are replaced by {@code quantile * signum(residual)} (Huber-style).
 *
 * @param exec monitor used for progress reporting and cancellation
 * @return the gradient boosted trees model
 * @throws CanceledExecutionException if the execution is canceled between or during tree learning
 */
@Override
public AbstractGradientBoostingModel learn(final ExecutionMonitor exec) throws CanceledExecutionException {
    final TreeData actualData = getData();
    final GradientBoostingLearnerConfiguration config = getConfig();
    final int nrModels = config.getNrModels();
    final TreeTargetNumericColumnData actualTarget = getTarget();
    final double initialValue = actualTarget.getMedian();
    final ArrayList<TreeModelRegression> models = new ArrayList<TreeModelRegression>(nrModels);
    final ArrayList<Map<TreeNodeSignature, Double>> coefficientMaps =
        new ArrayList<Map<TreeNodeSignature, Double>>(nrModels);
    final int nrRows = actualTarget.getNrRows();
    final double[] previousPrediction = new double[nrRows];
    Arrays.fill(previousPrediction, initialValue);
    final RandomData rd = config.createRandomData();
    final double alpha = config.getAlpha();

    final int maxLevels = config.getMaxLevels();
    final TreeNodeSignatureFactory signatureFactory;
    // this should be the default
    if (maxLevels < TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE) {
        // presize for a tree of bounded depth
        final int capacity = IntMath.pow(2, maxLevels - 1);
        signatureFactory = new TreeNodeSignatureFactory(capacity);
    } else {
        signatureFactory = new TreeNodeSignatureFactory();
    }

    exec.setMessage("Learning model");
    for (int i = 0; i < nrModels; i++) {
        exec.checkCanceled();
        final double[] residuals = new double[nrRows];
        for (int j = 0; j < nrRows; j++) {
            residuals[j] = actualTarget.getValueFor(j) - previousPrediction[j];
        }
        final double quantile = calculateAlphaQuantile(residuals, alpha);
        // clip residuals beyond the quantile so outliers do not dominate the next tree
        final double[] gradients = new double[nrRows];
        for (int j = 0; j < nrRows; j++) {
            gradients[j] = Math.abs(residuals[j]) <= quantile ? residuals[j] : quantile * Math.signum(residuals[j]);
        }
        final TreeData residualData = createResidualDataFromArray(gradients, actualData);
        // each tree gets its own seeded random source derived from the master source
        final RandomData rdSingle =
            TreeEnsembleLearnerConfiguration.createRandomData(rd.nextLong(Long.MIN_VALUE, Long.MAX_VALUE));
        final RowSample rowSample = getRowSampler().createRowSample(rdSingle);
        final TreeLearnerRegression treeLearner =
            new TreeLearnerRegression(config, residualData, getIndexManager(), signatureFactory, rdSingle, rowSample);
        final TreeModelRegression tree = treeLearner.learnSingleTree(exec, rdSingle);
        final Map<TreeNodeSignature, Double> coefficientMap = calcCoefficientMap(residuals, quantile, tree);
        adaptPreviousPrediction(previousPrediction, tree, coefficientMap);
        models.add(tree);
        coefficientMaps.add(coefficientMap);
        // i is 0-based: after this iteration (i + 1) models are finished
        final int finished = i + 1;
        exec.setProgress(((double) finished) / nrModels, "Finished level " + finished + "/" + nrModels);
    }
    return new GradientBoostedTreesModel(config, actualData.getMetaData(),
        models.toArray(new TreeModelRegression[models.size()]), actualData.getTreeType(), initialValue,
        coefficientMaps);
}
Usage of org.knime.base.node.mine.treeensemble2.sample.row.RowSample in the knime-core project (by KNIME).
Class RootDescendantDataMembershipsTest, method testCreateChildDataMemberships:
/**
 * Checks that child and grandchild memberships created from the root are bit-set based and
 * expose exactly the bit set they were created from.
 */
@Test
public void testCreateChildDataMemberships() {
    final TestDataGenerator generator = new TestDataGenerator(new TreeEnsembleLearnerConfiguration(false));
    final TreeData tennisData = generator.createTennisData();
    final DefaultDataIndexManager dataIndexManager = new DefaultDataIndexManager(tennisData);
    final int rowCount = tennisData.getNrRows();
    final RootDataMemberships root =
        new RootDataMemberships(new DefaultRowSample(rowCount), tennisData, dataIndexManager);

    // child covering rows [0, rowCount / 2)
    final BitSet half = new BitSet(rowCount);
    half.set(0, rowCount / 2);
    final DataMemberships child = root.createChildMemberships(half);
    assertThat(child, instanceOf(BitSetDescendantDataMemberships.class));
    assertEquals(half, ((BitSetDescendantDataMemberships) child).getBitSet());

    // grandchild covering rows [0, rowCount / 4), derived from the child
    final BitSet quarter = new BitSet(rowCount);
    quarter.set(0, rowCount / 4);
    final DataMemberships grandChild = child.createChildMemberships(quarter);
    assertThat(grandChild, instanceOf(BitSetDescendantDataMemberships.class));
    assertEquals(quarter, ((BitSetDescendantDataMemberships) grandChild).getBitSet());
}
Usage of org.knime.base.node.mine.treeensemble2.sample.row.RowSample in the knime-core project (by KNIME).
Class RootDescendantDataMembershipsTest, method testGetColumnMemberships:
/**
 * Checks column memberships for the root and for a descendant covering the last half of the rows.
 * For each visited element, the original index, the index in the column and the index in the
 * data memberships are compared against the expected values. The test also verifies that
 * iteration visits every expected element.
 */
@Test
public void testGetColumnMemberships() {
    TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(false);
    TestDataGenerator dataGen = new TestDataGenerator(config);
    TreeData data = dataGen.createTennisData();
    DefaultDataIndexManager indexManager = new DefaultDataIndexManager(data);
    int nrRows = data.getNrRows();
    RowSample rowSample = new DefaultRowSample(nrRows);
    RootDataMemberships rootMemberships = new RootDataMemberships(rowSample, data, indexManager);

    ColumnMemberships rootColMem = rootMemberships.getColumnMemberships(0);
    assertThat(rootColMem, instanceOf(IntArrayColumnMemberships.class));
    assertEquals(nrRows, rootColMem.size());
    int[] expectedRootIndices = new int[] { 0, 1, 7, 8, 10, 2, 6, 11, 12, 3, 4, 5, 9, 13 };
    int visited = 0;
    while (rootColMem.next()) {
        // in this case originalIndex and indexInDataMemberships are the same
        assertEquals(expectedRootIndices[visited], rootColMem.getOriginalIndex());
        assertEquals(expectedRootIndices[visited], rootColMem.getIndexInDataMemberships());
        assertEquals(visited, rootColMem.getIndexInColumn());
        visited++;
    }
    assertEquals(expectedRootIndices.length, visited);

    BitSet lastHalf = new BitSet(nrRows);
    lastHalf.set(nrRows / 2, nrRows);
    DataMemberships lastHalfChild = rootMemberships.createChildMemberships(lastHalf);
    ColumnMemberships childColMem = lastHalfChild.getColumnMemberships(0);
    assertThat(childColMem, instanceOf(DescendantColumnMemberships.class));
    assertEquals(nrRows / 2, childColMem.size());
    int[] expectedOriginalIndices = new int[] { 7, 8, 10, 11, 12, 9, 13 };
    int[] expectedIndexInColumn = new int[] { 2, 3, 4, 7, 8, 12, 13 };
    int[] expectedIndexInDataMemberships = new int[] { 7, 8, 10, 11, 12, 9, 13 };
    int childVisited = 0;
    while (childColMem.next()) {
        assertEquals(expectedOriginalIndices[childVisited], childColMem.getOriginalIndex());
        assertEquals(expectedIndexInColumn[childVisited], childColMem.getIndexInColumn());
        assertEquals(expectedIndexInDataMemberships[childVisited], childColMem.getIndexInDataMemberships());
        childVisited++;
    }
    assertEquals(expectedOriginalIndices.length, childVisited);
}
Aggregations