Use of org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel in project knime-core by knime.
Class RegressionGBTModelImporter, method importFromPMMLInternal:
/**
 * {@inheritDoc}
 */
@Override
public GradientBoostedTreesModel importFromPMMLInternal(final MiningModel miningModel) {
    Segmentation segmentation = miningModel.getSegmentation();
    CheckUtils.checkArgument(segmentation.getMultipleModelMethod() == MULTIPLEMODELMETHOD.SUM,
        "The provided segmentation has not the required sum as multiple model method but '%s' instead.",
        segmentation.getMultipleModelMethod());
    Pair<List<TreeModelRegression>, List<Map<TreeNodeSignature, Double>>> treesCoeffientMapsPair =
        readSumSegmentation(segmentation);
    List<TreeModelRegression> trees = treesCoeffientMapsPair.getFirst();
    // TODO user should be warned if there is no initial value or anything else is fishy
    double initialValue = miningModel.getTargets().getTargetList().get(0).getRescaleConstant();
    // currently only models learned on "ordinary" columns can be read back in
    return new GradientBoostedTreesModel(getMetaDataMapper().getTreeMetaData(),
        trees.toArray(new TreeModelRegression[trees.size()]), TreeType.Ordinary, initialValue,
        treesCoeffientMapsPair.getSecond());
}
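The importer reads the SUM segmentation back into a list of regression trees plus one coefficient map per tree, and takes the initial value from the PMML target's rescale constant. A minimal sketch of how a model assembled this way scores a row, in plain Java with hypothetical stand-ins for TreeModelRegression and TreeNodeSignature (this is not the KNIME prediction code):

import java.util.List;
import java.util.Map;
import java.util.function.Function;

final class GbtSumSketch {
    // prediction = initial value (rescale constant) + sum over trees of the
    // coefficient attached to the leaf the row falls into
    static double predict(final double initialValue,
        final List<Function<double[], String>> leafLookups, // one "tree" per segment, returning a leaf id
        final List<Map<String, Double>> coefficientMaps,    // leaf id -> coefficient, one map per tree
        final double[] row) {
        double sum = initialValue;
        for (int t = 0; t < leafLookups.size(); t++) {
            String leaf = leafLookups.get(t).apply(row);
            sum += coefficientMaps.get(t).getOrDefault(leaf, 0.0);
        }
        return sum;
    }
}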
Use of org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel in project knime-core by knime.
Class MGradientBoostedTreesLearner, method learn:
/**
 * {@inheritDoc}
 */
@Override
public AbstractGradientBoostingModel learn(final ExecutionMonitor exec) throws CanceledExecutionException {
    final TreeData actualData = getData();
    final GradientBoostingLearnerConfiguration config = getConfig();
    final int nrModels = config.getNrModels();
    final TreeTargetNumericColumnData actualTarget = getTarget();
    final double initialValue = actualTarget.getMedian();
    final ArrayList<TreeModelRegression> models = new ArrayList<TreeModelRegression>(nrModels);
    final ArrayList<Map<TreeNodeSignature, Double>> coefficientMaps =
        new ArrayList<Map<TreeNodeSignature, Double>>(nrModels);
    final double[] previousPrediction = new double[actualTarget.getNrRows()];
    Arrays.fill(previousPrediction, initialValue);
    final RandomData rd = config.createRandomData();
    final double alpha = config.getAlpha();
    TreeNodeSignatureFactory signatureFactory = null;
    final int maxLevels = config.getMaxLevels();
    // this should be the default
    if (maxLevels < TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE) {
        final int capacity = IntMath.pow(2, maxLevels - 1);
        signatureFactory = new TreeNodeSignatureFactory(capacity);
    } else {
        signatureFactory = new TreeNodeSignatureFactory();
    }
    exec.setMessage("Learning model");
    TreeData residualData;
    for (int i = 0; i < nrModels; i++) {
        final double[] residuals = new double[actualTarget.getNrRows()];
        for (int j = 0; j < actualTarget.getNrRows(); j++) {
            residuals[j] = actualTarget.getValueFor(j) - previousPrediction[j];
        }
        final double quantile = calculateAlphaQuantile(residuals, alpha);
        final double[] gradients = new double[residuals.length];
        for (int j = 0; j < gradients.length; j++) {
            gradients[j] = Math.abs(residuals[j]) <= quantile ? residuals[j] : quantile * Math.signum(residuals[j]);
        }
        residualData = createResidualDataFromArray(gradients, actualData);
        final RandomData rdSingle =
            TreeEnsembleLearnerConfiguration.createRandomData(rd.nextLong(Long.MIN_VALUE, Long.MAX_VALUE));
        final RowSample rowSample = getRowSampler().createRowSample(rdSingle);
        final TreeLearnerRegression treeLearner = new TreeLearnerRegression(getConfig(), residualData,
            getIndexManager(), signatureFactory, rdSingle, rowSample);
        final TreeModelRegression tree = treeLearner.learnSingleTree(exec, rdSingle);
        final Map<TreeNodeSignature, Double> coefficientMap = calcCoefficientMap(residuals, quantile, tree);
        adaptPreviousPrediction(previousPrediction, tree, coefficientMap);
        models.add(tree);
        coefficientMaps.add(coefficientMap);
        exec.setProgress(((double) i) / nrModels, "Finished level " + i + "/" + nrModels);
    }
    return new GradientBoostedTreesModel(getConfig(), actualData.getMetaData(),
        models.toArray(new TreeModelRegression[models.size()]), actualData.getTreeType(), initialValue,
        coefficientMaps);
}
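The loop above is gradient boosting with a Huber-style loss: residuals against the running prediction are clipped at the alpha quantile of their absolute values before the next tree is fit to them, which keeps single outliers from dominating a boosting iteration. A minimal sketch of that step in plain Java; alphaQuantile is a hypothetical helper and not KNIME's calculateAlphaQuantile:

import java.util.Arrays;

final class HuberGradientSketch {
    // alpha quantile of |residuals|, e.g. alpha = 0.9
    static double alphaQuantile(final double[] residuals, final double alpha) {
        double[] abs = Arrays.stream(residuals).map(Math::abs).sorted().toArray();
        int index = Math.min(abs.length - 1, (int) Math.floor(alpha * abs.length));
        return abs[index];
    }

    // pseudo-gradients: keep small residuals, clip large ones to +/- quantile
    static double[] clippedGradients(final double[] residuals, final double quantile) {
        double[] gradients = new double[residuals.length];
        for (int j = 0; j < residuals.length; j++) {
            gradients[j] = Math.abs(residuals[j]) <= quantile
                ? residuals[j] : quantile * Math.signum(residuals[j]);
        }
        return gradients;
    }
}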
Use of org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel in project knime-core by knime.
Class TreeEnsembleModel, method load:
public static TreeEnsembleModel load(final InputStream in) throws IOException {
    // wrapping the argument (zip input) stream in a buffered stream
    // reduces read operation from, e.g. 42s to 2s
    TreeModelDataInputStream input =
        new TreeModelDataInputStream(new BufferedInputStream(new NonClosableInputStream(in)));
    int version = input.readInt();
    if (version > 20160114) {
        throw new IOException("Tree Ensemble version " + version + " not supported");
    }
    byte ensembleType;
    if (version == 20160114) {
        ensembleType = input.readByte();
    } else {
        ensembleType = 'r';
    }
    TreeType type = TreeType.load(input);
    TreeMetaData metaData = TreeMetaData.load(input);
    int nrModels = input.readInt();
    boolean containsClassDistribution;
    if (version == 20121019) {
        containsClassDistribution = true;
    } else {
        containsClassDistribution = input.readBoolean();
    }
    input.setContainsClassDistribution(containsClassDistribution);
    AbstractTreeModel[] models = new AbstractTreeModel[nrModels];
    boolean isRegression = metaData.isRegression();
    if (ensembleType != 'r') {
        isRegression = true;
    }
    final TreeBuildingInterner treeBuildingInterner = new TreeBuildingInterner();
    for (int i = 0; i < nrModels; i++) {
        AbstractTreeModel singleModel;
        try {
            singleModel = isRegression ? TreeModelRegression.load(input, metaData, treeBuildingInterner)
                : TreeModelClassification.load(input, metaData, treeBuildingInterner);
            if (input.readByte() != 0) {
                throw new IOException("Model not terminated by 0 byte");
            }
        } catch (IOException e) {
            throw new IOException("Can't read tree model " + (i + 1) + "/" + nrModels + ": " + e.getMessage(), e);
        }
        models[i] = singleModel;
    }
    TreeEnsembleModel result;
    switch (ensembleType) {
        case 'r':
            result = new TreeEnsembleModel(metaData, models, type, containsClassDistribution);
            break;
        case 'g':
            result = new GradientBoostingModel(metaData, models, type, containsClassDistribution);
            break;
        case 't':
            result = new GradientBoostedTreesModel(metaData, models, type, containsClassDistribution);
            break;
        case 'm':
            result = new MultiClassGradientBoostedTreesModel(metaData, models, type, containsClassDistribution);
            break;
        default:
            throw new IllegalStateException("Unknown ensemble type: '" + (char) ensembleType + "'");
    }
    result.loadData(input);
    // does not close the method argument stream!!
    input.close();
    return result;
}
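The stream begins with a version int, then (for version 20160114) an ensemble type byte ('r' plain tree ensemble, 'g' gradient boosting, 't' gradient boosted trees, 'm' multi-class gradient boosted trees), and every tree payload is terminated by a 0 byte. A minimal sketch of that header framing as a matching write/read pair; the tree type, meta data and per-tree payloads that sit between the type byte and the model count in the real stream are left out because their encoding is KNIME-internal:

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

final class EnsembleHeaderSketch {
    static byte[] writeHeader(final int version, final byte ensembleType, final int nrModels) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        DataOutputStream out = new DataOutputStream(bytes);
        out.writeInt(version);           // e.g. 20160114
        if (version == 20160114) {
            out.writeByte(ensembleType); // 'r', 'g', 't' or 'm'
        }
        out.writeInt(nrModels);
        out.close();
        return bytes.toByteArray();
    }

    static void readHeader(final byte[] raw) throws IOException {
        DataInputStream in = new DataInputStream(new ByteArrayInputStream(raw));
        int version = in.readInt();
        // older streams carry no type byte and default to a plain ensemble
        byte ensembleType = version == 20160114 ? in.readByte() : (byte) 'r';
        int nrModels = in.readInt();
        System.out.println("version=" + version + ", type=" + (char) ensembleType + ", models=" + nrModels);
    }
}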
Use of org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel in project knime-core by knime.
Class RegressionGBTModelExporter, method doWrite:
/**
 * {@inheritDoc}
 */
@Override
protected void doWrite(final MiningModel model) {
    // write the initial value
    Targets targets = model.addNewTargets();
    Target target = targets.addNewTarget();
    GradientBoostedTreesModel gbtModel = getGBTModel();
    target.setField(gbtModel.getMetaData().getTargetMetaData().getAttributeName());
    target.setRescaleConstant(gbtModel.getInitialValue());
    // write the model
    Segmentation segmentation = model.addNewSegmentation();
    List<TreeModelRegression> trees = IntStream.range(0, gbtModel.getNrModels())
        .mapToObj(gbtModel::getTreeModelRegression).collect(Collectors.toList());
    writeSumSegmentation(segmentation, trees, gbtModel.getCoeffientMaps());
}
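On the PMML side this yields a Target carrying the initial value as its rescaleConstant plus a Segmentation with multipleModelMethod "sum" holding one tree model per boosting iteration, the counterpart of what RegressionGBTModelImporter checks above. A rough illustration of that shape, rendered as a plain string; this is a hypothetical helper, not the exporter's XMLBeans calls, and the segment contents are simplified:

final class GbtPmmlShapeSketch {
    static String sketch(final String targetField, final double initialValue, final int nrTrees) {
        StringBuilder sb = new StringBuilder();
        sb.append("<Targets>\n")
          .append("  <Target field=\"").append(targetField)
          .append("\" rescaleConstant=\"").append(initialValue).append("\"/>\n")
          .append("</Targets>\n")
          .append("<Segmentation multipleModelMethod=\"sum\">\n");
        for (int i = 0; i < nrTrees; i++) {
            sb.append("  <Segment id=\"").append(i + 1)
              .append("\"><!-- one TreeModel per boosting iteration --></Segment>\n");
        }
        return sb.append("</Segmentation>").toString();
    }
}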
Use of org.knime.base.node.mine.treeensemble2.model.GradientBoostedTreesModel in project knime-core by knime.
Class GradientBoostingPredictorCellFactory, method createFactory:
public static GradientBoostingPredictorCellFactory createFactory(
    final GradientBoostingPredictor<GradientBoostedTreesModel> predictor) throws InvalidSettingsException {
    TreeEnsembleModelPortObjectSpec modelSpec = predictor.getModelSpec();
    DataTableSpec learnSpec = modelSpec.getLearnTableSpec();
    DataTableSpec testSpec = predictor.getDataSpec();
    UniqueNameGenerator nameGen = new UniqueNameGenerator(testSpec);
    DataColumnSpec newColSpec =
        nameGen.newColumn(predictor.getConfiguration().getPredictionColumnName(), DoubleCell.TYPE);
    return new GradientBoostingPredictorCellFactory(newColSpec, predictor.getModel(), learnSpec,
        modelSpec.calculateFilterIndices(testSpec));
}
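The factory appends a single double column whose name is made collision-free by the UniqueNameGenerator, and the filter indices map the test table's columns onto the learn table's layout before a row is handed to the model. A minimal per-row sketch in plain Java; gbtPredictor and the double[] row are hypothetical stand-ins for the KNIME DataRow handling:

import java.util.function.ToDoubleFunction;

final class PredictionCellSketch {
    static double predictCell(final double[] testRow, final int[] filterIndices,
        final ToDoubleFunction<double[]> gbtPredictor) {
        double[] modelInput = new double[filterIndices.length];
        for (int i = 0; i < filterIndices.length; i++) {
            modelInput[i] = testRow[filterIndices[i]]; // reorder to the learn table's column order
        }
        // e.g. initial value + sum of per-tree leaf coefficients, as in the import sketch above
        return gbtPredictor.applyAsDouble(modelInput);
    }
}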