Use of org.knime.base.node.mine.treeensemble2.learner.TreeLearnerRegression in project knime-core by knime.
The class MGradientBoostedTreesLearner, method learn.
/**
 * {@inheritDoc}
 *
 * <p>
 * Builds the ensemble one regression tree at a time. The prediction for every row starts at the median of the
 * target column. In each of the {@code nrModels} iterations the method
 * <ol>
 * <li>computes the residuals between the target values and the current prediction,</li>
 * <li>determines the alpha-quantile of the residuals via {@code calculateAlphaQuantile},</li>
 * <li>derives the gradients by keeping residuals whose absolute value does not exceed the quantile and
 * replacing all others by the quantile carrying the residual's sign,</li>
 * <li>learns a single regression tree on these gradients using a freshly seeded random generator and row
 * sample,</li>
 * <li>computes the per-leaf coefficients with {@code calcCoefficientMap} and updates the current prediction
 * with {@code adaptPreviousPrediction}.</li>
 * </ol>
 *
 * @param exec monitor used to publish progress; after each finished tree the completed fraction is reported
 * @return a {@code GradientBoostedTreesModel} holding the learned trees, the initial value (target median) and
 *         one coefficient map per tree
 * @throws CanceledExecutionException declared by the overridden method; presumably raised by the tree learner
 *             when the user cancels - not thrown directly in this method
 */
@Override
public AbstractGradientBoostingModel learn(final ExecutionMonitor exec) throws CanceledExecutionException {
    final TreeData actualData = getData();
    final GradientBoostingLearnerConfiguration config = getConfig();
    final int nrModels = config.getNrModels();
    final TreeTargetNumericColumnData actualTarget = getTarget();
    // the row count does not change while learning, so query it only once
    final int nrRows = actualTarget.getNrRows();
    final double initialValue = actualTarget.getMedian();
    final ArrayList<TreeModelRegression> models = new ArrayList<TreeModelRegression>(nrModels);
    final ArrayList<Map<TreeNodeSignature, Double>> coefficientMaps =
        new ArrayList<Map<TreeNodeSignature, Double>>(nrModels);
    // running prediction of the ensemble learned so far; starts with the constant initial value
    final double[] previousPrediction = new double[nrRows];
    Arrays.fill(previousPrediction, initialValue);
    final RandomData rd = config.createRandomData();
    final double alpha = config.getAlpha();
    TreeNodeSignatureFactory signatureFactory = null;
    final int maxLevels = config.getMaxLevels();
    // this should be the default
    if (maxLevels < TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE) {
        final int capacity = IntMath.pow(2, maxLevels - 1);
        signatureFactory = new TreeNodeSignatureFactory(capacity);
    } else {
        signatureFactory = new TreeNodeSignatureFactory();
    }
    exec.setMessage("Learning model");
    TreeData residualData;
    for (int i = 0; i < nrModels; i++) {
        final double[] residuals = new double[nrRows];
        for (int j = 0; j < nrRows; j++) {
            residuals[j] = actualTarget.getValueFor(j) - previousPrediction[j];
        }
        final double quantile = calculateAlphaQuantile(residuals, alpha);
        // residuals beyond the quantile are clipped to +/- quantile
        final double[] gradients = new double[residuals.length];
        for (int j = 0; j < gradients.length; j++) {
            gradients[j] =
                Math.abs(residuals[j]) <= quantile ? residuals[j] : quantile * Math.signum(residuals[j]);
        }
        residualData = createResidualDataFromArray(gradients, actualData);
        // every tree gets its own generator, seeded from the master generator
        final RandomData rdSingle =
            TreeEnsembleLearnerConfiguration.createRandomData(rd.nextLong(Long.MIN_VALUE, Long.MAX_VALUE));
        final RowSample rowSample = getRowSampler().createRowSample(rdSingle);
        final TreeLearnerRegression treeLearner = new TreeLearnerRegression(getConfig(), residualData,
            getIndexManager(), signatureFactory, rdSingle, rowSample);
        final TreeModelRegression tree = treeLearner.learnSingleTree(exec, rdSingle);
        final Map<TreeNodeSignature, Double> coefficientMap = calcCoefficientMap(residuals, quantile, tree);
        adaptPreviousPrediction(previousPrediction, tree, coefficientMap);
        models.add(tree);
        coefficientMaps.add(coefficientMap);
        // i is zero-based: once this iteration is done, i + 1 trees are finished
        final int finished = i + 1;
        exec.setProgress(((double) finished) / nrModels, "Finished level " + finished + "/" + nrModels);
    }
    return new GradientBoostedTreesModel(getConfig(), actualData.getMetaData(),
        models.toArray(new TreeModelRegression[models.size()]), actualData.getTreeType(), initialValue,
        coefficientMaps);
}
Use of org.knime.base.node.mine.treeensemble2.learner.TreeLearnerRegression in project knime-core by knime.
The class RegressionTreeLearnerNodeModel, method execute.
/**
 * {@inheritDoc}
 *
 * <p>
 * Restricts the first input table to the configured learn columns, reads it into memory, learns a single
 * regression tree on it and returns that tree wrapped in a {@code RegressionTreeModelPortObject}. Warnings from
 * the column filter and from the data creation are joined (one per line) and set as the node warning.
 *
 * @param inObjects the input ports; index 0 is cast to {@code BufferedDataTable}
 * @param exec execution context used for table creation and progress reporting
 * @return a one-element array holding the learned tree model port object
 * @throws Exception if any of the called steps fails
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final BufferedDataTable inputTable = (BufferedDataTable) inObjects[0];
    final FilterLearnColumnRearranger learnRearranger =
        m_configuration.filterLearnColumns(inputTable.getDataTableSpec());
    String warningMessage = learnRearranger.getWarning();
    final BufferedDataTable learnTable =
        exec.createColumnRearrangeTable(inputTable, learnRearranger, exec.createSubProgress(0.0));
    final DataTableSpec learnSpec = learnTable.getDataTableSpec();

    // 10% of the progress is reserved for reading the data, 90% for learning
    final ExecutionMonitor readInExec = exec.createSubProgress(0.1);
    final ExecutionMonitor learnExec = exec.createSubProgress(0.9);

    final TreeDataCreator dataCreator = new TreeDataCreator(m_configuration, learnSpec, learnTable.getRowCount());
    exec.setProgress("Reading data into memory");
    final TreeData data = dataCreator.readData(learnTable, m_configuration, readInExec);
    m_hiliteRowSample = dataCreator.getDataRowsForHilite();
    m_viewMessage = dataCreator.getViewMessage();

    // append the data creation warning (if any) to the column filter warning
    final String dataCreationWarning = dataCreator.getAndClearWarningMessage();
    if (dataCreationWarning != null) {
        warningMessage =
            (warningMessage == null) ? dataCreationWarning : warningMessage + "\n" + dataCreationWarning;
    }
    readInExec.setProgress(1.0);

    exec.setMessage("Learning tree");
    final RandomData random = m_configuration.createRandomData();
    final IDataIndexManager indexManager = (data.getTreeType() == TreeType.BitVector)
        ? new BitVectorDataIndexManager(data.getNrRows())
        : new DefaultDataIndexManager(data);

    // a bounded depth allows the signature factory to be created with an initial capacity
    final int maxLevels = m_configuration.getMaxLevels();
    final TreeNodeSignatureFactory signatureFactory =
        (maxLevels < TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE)
            ? new TreeNodeSignatureFactory(IntMath.pow(2, maxLevels - 1))
            : new TreeNodeSignatureFactory();

    final RowSample rowSample = m_configuration.createRowSampler(data).createRowSample(random);
    final TreeLearnerRegression learner =
        new TreeLearnerRegression(m_configuration, data, indexManager, signatureFactory, random, rowSample);
    final TreeModelRegression learnedTree = learner.learnSingleTree(learnExec, random);

    final RegressionTreeModel treeModel =
        new RegressionTreeModel(m_configuration, data.getMetaData(), learnedTree, data.getTreeType());
    final RegressionTreeModelPortObject result =
        new RegressionTreeModelPortObject(treeModel, new RegressionTreeModelPortObjectSpec(learnSpec));
    learnExec.setProgress(1.0);
    m_treeModelPortObject = result;

    if (warningMessage != null) {
        setWarningMessage(warningMessage);
    }
    return new PortObject[]{result};
}
Aggregations