use of org.knime.base.node.mine.treeensemble2.model.AbstractGradientBoostingModel in project knime-core by knime.
the class GBTPMMLExporterNodeModel method execute.
/**
 * Exports the gradient boosted trees model found at the first input port as a PMML port object.
 * <p>
 * A regression translator is used for a {@code GradientBoostedTreesModel} and a classification
 * translator for a {@code MultiClassGradientBoostedTreesModel}; both receive the learn table spec
 * stored in the port object spec.
 *
 * {@inheritDoc}
 *
 * @throws IllegalArgumentException if the ensemble model is of neither supported type
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final GradientBoostingModelPortObject modelPort = (GradientBoostingModelPortObject) inObjects[0];
    final AbstractGradientBoostingModel ensemble = modelPort.getEnsembleModel();

    // Pick the translator that matches the concrete model type.
    final AbstractGBTModelPMMLTranslator<?> pmmlTranslator;
    if (ensemble instanceof GradientBoostedTreesModel) {
        final GradientBoostedTreesModel regressionModel = (GradientBoostedTreesModel) ensemble;
        pmmlTranslator =
            new RegressionGBTModelPMMLTranslator(regressionModel, modelPort.getSpec().getLearnTableSpec());
    } else if (ensemble instanceof MultiClassGradientBoostedTreesModel) {
        final MultiClassGradientBoostedTreesModel classificationModel =
            (MultiClassGradientBoostedTreesModel) ensemble;
        pmmlTranslator =
            new ClassificationGBTModelPMMLTranslator(classificationModel, modelPort.getSpec().getLearnTableSpec());
    } else {
        final String typeName = ensemble.getClass().getSimpleName();
        throw new IllegalArgumentException("Unknown gradient boosted trees model type '" + typeName + "'.");
    }

    // Wrap the translator into a PMML port object built from the derived PMML spec.
    final PMMLPortObjectSpec outSpec = createPMMLSpec(modelPort.getSpec(), ensemble);
    final PMMLPortObject pmmlPort = new PMMLPortObject(outSpec);
    pmmlPort.addModelTranslater(pmmlTranslator);
    return new PortObject[] {pmmlPort};
}
use of org.knime.base.node.mine.treeensemble2.model.AbstractGradientBoostingModel in project knime-core by knime.
the class MGradientBoostedTreesLearner method learn.
/**
 * Learns a gradient boosted ensemble of regression trees for a numeric target.
 * <p>
 * The ensemble starts from the median of the target column. In every boosting round the residuals
 * between the target and the current prediction are computed, residuals whose magnitude exceeds the
 * configured alpha quantile are clipped to {@code +/- quantile}, and a regression tree is fitted to
 * these clipped values on a freshly drawn row sample. Leaf coefficients are then derived from the
 * unclipped residuals and the running prediction is updated before the next round.
 *
 * {@inheritDoc}
 *
 * @param exec monitor that receives progress messages and is passed to the single-tree learner
 * @return a {@code GradientBoostedTreesModel} holding the learned trees, the initial value and one
 *         coefficient map per tree
 * @throws CanceledExecutionException if learning is canceled
 */
@Override
public AbstractGradientBoostingModel learn(final ExecutionMonitor exec) throws CanceledExecutionException {
    final TreeData actualData = getData();
    final GradientBoostingLearnerConfiguration config = getConfig();
    final int nrModels = config.getNrModels();
    final TreeTargetNumericColumnData actualTarget = getTarget();
    // The target median is the constant prediction every row starts with.
    final double initialValue = actualTarget.getMedian();
    final ArrayList<TreeModelRegression> models = new ArrayList<TreeModelRegression>(nrModels);
    final ArrayList<Map<TreeNodeSignature, Double>> coefficientMaps =
        new ArrayList<Map<TreeNodeSignature, Double>>(nrModels);
    // The row count does not change while learning, so it is read only once.
    final int nrRows = actualTarget.getNrRows();
    final double[] previousPrediction = new double[nrRows];
    Arrays.fill(previousPrediction, initialValue);
    final RandomData rd = config.createRandomData();
    final double alpha = config.getAlpha();

    final TreeNodeSignatureFactory signatureFactory;
    final int maxLevels = config.getMaxLevels();
    // this should be the default
    if (maxLevels < TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE) {
        // Bounded depth: size the factory from the level limit.
        final int capacity = IntMath.pow(2, maxLevels - 1);
        signatureFactory = new TreeNodeSignatureFactory(capacity);
    } else {
        signatureFactory = new TreeNodeSignatureFactory();
    }

    exec.setMessage("Learning model");
    for (int i = 0; i < nrModels; i++) {
        // Residuals of the current ensemble prediction.
        final double[] residuals = new double[nrRows];
        for (int j = 0; j < nrRows; j++) {
            residuals[j] = actualTarget.getValueFor(j) - previousPrediction[j];
        }

        // Clip residuals beyond the alpha quantile; the tree is fitted to the clipped values.
        final double quantile = calculateAlphaQuantile(residuals, alpha);
        final double[] gradients = new double[nrRows];
        for (int j = 0; j < nrRows; j++) {
            final double residual = residuals[j];
            gradients[j] = Math.abs(residual) <= quantile ? residual : quantile * Math.signum(residual);
        }
        final TreeData residualData = createResidualDataFromArray(gradients, actualData);

        // Every tree gets its own random source, seeded from the ensemble-level generator.
        final RandomData rdSingle =
            TreeEnsembleLearnerConfiguration.createRandomData(rd.nextLong(Long.MIN_VALUE, Long.MAX_VALUE));
        final RowSample rowSample = getRowSampler().createRowSample(rdSingle);
        final TreeLearnerRegression treeLearner = new TreeLearnerRegression(getConfig(), residualData,
            getIndexManager(), signatureFactory, rdSingle, rowSample);
        final TreeModelRegression tree = treeLearner.learnSingleTree(exec, rdSingle);

        // Coefficients are computed from the unclipped residuals.
        final Map<TreeNodeSignature, Double> coefficientMap = calcCoefficientMap(residuals, quantile, tree);
        adaptPreviousPrediction(previousPrediction, tree, coefficientMap);
        models.add(tree);
        coefficientMaps.add(coefficientMap);

        // i is zero-based: after this iteration i + 1 trees are finished.
        final int finished = i + 1;
        exec.setProgress(((double) finished) / nrModels, "Finished level " + finished + "/" + nrModels);
    }
    return new GradientBoostedTreesModel(getConfig(), actualData.getMetaData(),
        models.toArray(new TreeModelRegression[models.size()]), actualData.getTreeType(), initialValue,
        coefficientMaps);
}
use of org.knime.base.node.mine.treeensemble2.model.AbstractGradientBoostingModel in project knime-core by knime.
the class GradientBoostingClassificationLearnerNodeModel method execute.
/**
 * Trains a gradient boosted trees classification model on the table at the first input port.
 * <p>
 * The input is reduced to the configured learn columns, read into a {@code TreeData} object and
 * passed to an {@code LKGradientBoostedTreesLearner}. Warnings from the column filter and from the
 * data creator are joined with a line break and set as the node warning.
 *
 * {@inheritDoc}
 */
@Override
protected PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws Exception {
    final BufferedDataTable inputTable = (BufferedDataTable) inData[0];

    // Restrict the input to the columns that take part in learning.
    final FilterLearnColumnRearranger columnFilter =
        m_configuration.filterLearnColumns(inputTable.getDataTableSpec());
    String warning = columnFilter.getWarning();
    final BufferedDataTable learnTable =
        exec.createColumnRearrangeTable(inputTable, columnFilter, exec.createSubProgress(0.0));
    final DataTableSpec learnSpec = learnTable.getDataTableSpec();
    final TreeEnsembleModelPortObjectSpec modelSpec = m_configuration.createPortObjectSpec(learnSpec);

    // Progress budget: 10% for reading the data, 80% for learning.
    final ExecutionMonitor readExec = exec.createSubProgress(0.1);
    final ExecutionMonitor trainExec = exec.createSubProgress(0.8);

    final TreeDataCreator creator = new TreeDataCreator(m_configuration, learnSpec, learnTable.getRowCount());
    exec.setProgress("Reading data into memory");
    final TreeData treeData = creator.readData(learnTable, m_configuration, readExec);

    // Append the data creator's warning (if any) to the column filter's warning.
    final String creatorWarning = creator.getAndClearWarningMessage();
    if (creatorWarning != null) {
        warning = (warning == null) ? creatorWarning : warning + "\n" + creatorWarning;
    }
    readExec.setProgress(1.0);

    exec.setMessage("Learning trees");
    final AbstractGradientBoostingLearner boostingLearner =
        new LKGradientBoostedTreesLearner(m_configuration, treeData);
    final AbstractGradientBoostingModel trainedModel = boostingLearner.learn(trainExec);
    final GradientBoostingModelPortObject result = new GradientBoostingModelPortObject(modelSpec, trainedModel);
    trainExec.setProgress(1.0);

    if (warning != null) {
        setWarningMessage(warning);
    }
    return new PortObject[] {result};
}
use of org.knime.base.node.mine.treeensemble2.model.AbstractGradientBoostingModel in project knime-core by knime.
the class GradientBoostingRegressionLearnerNodeModel method execute.
/**
 * Trains a gradient boosted trees regression model on the table at the first input port.
 * <p>
 * The input is reduced to the configured learn columns, read into a {@code TreeData} object and
 * passed to an {@code MGradientBoostedTreesLearner}. Warnings from the column filter and from the
 * data creator are joined with a line break and set as the node warning.
 *
 * {@inheritDoc}
 */
@Override
protected PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws Exception {
    BufferedDataTable t = (BufferedDataTable) inData[0];
    DataTableSpec spec = t.getDataTableSpec();

    // Restrict the input to the columns that take part in learning.
    final FilterLearnColumnRearranger learnRearranger = m_configuration.filterLearnColumns(spec);
    String warn = learnRearranger.getWarning();
    BufferedDataTable learnTable = exec.createColumnRearrangeTable(t, learnRearranger, exec.createSubProgress(0.0));
    DataTableSpec learnSpec = learnTable.getDataTableSpec();
    TreeEnsembleModelPortObjectSpec ensembleSpec = m_configuration.createPortObjectSpec(learnSpec);

    // Progress budget: 10% for reading the data, 80% for learning. No further sub-progress is
    // reserved because no other phase reports progress in this method.
    ExecutionMonitor readInExec = exec.createSubProgress(0.1);
    ExecutionMonitor learnExec = exec.createSubProgress(0.8);

    TreeDataCreator dataCreator = new TreeDataCreator(m_configuration, learnSpec, learnTable.getRowCount());
    exec.setProgress("Reading data into memory");
    TreeData data = dataCreator.readData(learnTable, m_configuration, readInExec);

    // Append the data creator's warning (if any) to the column filter's warning.
    String dataCreationWarning = dataCreator.getAndClearWarningMessage();
    if (dataCreationWarning != null) {
        if (warn == null) {
            warn = dataCreationWarning;
        } else {
            warn = warn + "\n" + dataCreationWarning;
        }
    }
    readInExec.setProgress(1.0);

    exec.setMessage("Learning trees");
    AbstractGradientBoostingLearner learner = new MGradientBoostedTreesLearner(m_configuration, data);
    AbstractGradientBoostingModel model = learner.learn(learnExec);
    GradientBoostingModelPortObject modelPortObject = new GradientBoostingModelPortObject(ensembleSpec, model);
    learnExec.setProgress(1.0);

    if (warn != null) {
        setWarningMessage(warn);
    }
    return new PortObject[] { modelPortObject };
}
Aggregations