Usage of `org.knime.base.node.mine.treeensemble2.learner.gradientboosting.AbstractGradientBoostingLearner` in the project knime-core by KNIME.
Class `GradientBoostingClassificationLearnerNodeModel`, method `execute`.
/**
 * {@inheritDoc}
 * <p>
 * Filters the learning columns of the first input table, reads them into memory, fits a gradient boosted
 * trees model with {@link LKGradientBoostedTreesLearner} and returns it wrapped in a
 * {@link GradientBoostingModelPortObject}. Warnings produced while filtering the columns and while reading
 * the data are joined (newline-separated) and set as the node's warning message.
 */
@Override
protected PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws Exception {
    final BufferedDataTable t = (BufferedDataTable)inData[0];
    final DataTableSpec spec = t.getDataTableSpec();
    final FilterLearnColumnRearranger learnRearranger = m_configuration.filterLearnColumns(spec);
    String warn = learnRearranger.getWarning();
    // The column rearrangement is given no share of the overall progress.
    final BufferedDataTable learnTable =
        exec.createColumnRearrangeTable(t, learnRearranger, exec.createSubProgress(0.0));
    final DataTableSpec learnSpec = learnTable.getDataTableSpec();
    final TreeEnsembleModelPortObjectSpec ensembleSpec = m_configuration.createPortObjectSpec(learnSpec);

    // Progress budget: 10% for reading the data into memory, 90% for learning (sums to 1.0 so the
    // progress bar can actually complete).
    final ExecutionMonitor readInExec = exec.createSubProgress(0.1);
    final ExecutionMonitor learnExec = exec.createSubProgress(0.9);

    final TreeDataCreator dataCreator = new TreeDataCreator(m_configuration, learnSpec, learnTable.getRowCount());
    exec.setMessage("Reading data into memory");
    final TreeData data = dataCreator.readData(learnTable, m_configuration, readInExec);
    final String dataCreationWarning = dataCreator.getAndClearWarningMessage();
    if (dataCreationWarning != null) {
        // Keep any column-filter warning and append the data-reading warning on its own line.
        warn = (warn == null) ? dataCreationWarning : warn + "\n" + dataCreationWarning;
    }
    readInExec.setProgress(1.0);

    exec.setMessage("Learning trees");
    final AbstractGradientBoostingLearner learner = new LKGradientBoostedTreesLearner(m_configuration, data);
    final AbstractGradientBoostingModel model = learner.learn(learnExec);
    final GradientBoostingModelPortObject modelPortObject = new GradientBoostingModelPortObject(ensembleSpec, model);
    learnExec.setProgress(1.0);

    if (warn != null) {
        setWarningMessage(warn);
    }
    return new PortObject[]{modelPortObject};
}
Usage of `org.knime.base.node.mine.treeensemble2.learner.gradientboosting.AbstractGradientBoostingLearner` in the project knime-core by KNIME.
Class `GradientBoostingRegressionLearnerNodeModel`, method `execute`.
/**
 * {@inheritDoc}
 * <p>
 * Filters the learning columns of the first input table, reads them into memory, fits a gradient boosted
 * trees model with {@link MGradientBoostedTreesLearner} and returns it wrapped in a
 * {@link GradientBoostingModelPortObject}. Warnings produced while filtering the columns and while reading
 * the data are joined (newline-separated) and set as the node's warning message.
 */
@Override
protected PortObject[] execute(final PortObject[] inData, final ExecutionContext exec) throws Exception {
    final BufferedDataTable t = (BufferedDataTable)inData[0];
    final DataTableSpec spec = t.getDataTableSpec();
    final FilterLearnColumnRearranger learnRearranger = m_configuration.filterLearnColumns(spec);
    String warn = learnRearranger.getWarning();
    // The column rearrangement is given no share of the overall progress.
    final BufferedDataTable learnTable =
        exec.createColumnRearrangeTable(t, learnRearranger, exec.createSubProgress(0.0));
    final DataTableSpec learnSpec = learnTable.getDataTableSpec();
    final TreeEnsembleModelPortObjectSpec ensembleSpec = m_configuration.createPortObjectSpec(learnSpec);

    // Progress budget: 10% for reading the data into memory, 90% for learning. No share is reserved for a
    // phase this method does not run, so the progress bar can actually complete.
    final ExecutionMonitor readInExec = exec.createSubProgress(0.1);
    final ExecutionMonitor learnExec = exec.createSubProgress(0.9);

    final TreeDataCreator dataCreator = new TreeDataCreator(m_configuration, learnSpec, learnTable.getRowCount());
    exec.setMessage("Reading data into memory");
    final TreeData data = dataCreator.readData(learnTable, m_configuration, readInExec);
    final String dataCreationWarning = dataCreator.getAndClearWarningMessage();
    if (dataCreationWarning != null) {
        // Keep any column-filter warning and append the data-reading warning on its own line.
        warn = (warn == null) ? dataCreationWarning : warn + "\n" + dataCreationWarning;
    }
    readInExec.setProgress(1.0);

    exec.setMessage("Learning trees");
    final AbstractGradientBoostingLearner learner = new MGradientBoostedTreesLearner(m_configuration, data);
    final AbstractGradientBoostingModel model = learner.learn(learnExec);
    final GradientBoostingModelPortObject modelPortObject = new GradientBoostingModelPortObject(ensembleSpec, model);
    learnExec.setProgress(1.0);

    if (warn != null) {
        setWarningMessage(warn);
    }
    return new PortObject[]{modelPortObject};
}
Aggregations