Use of org.knime.base.node.mine.treeensemble2.model.RegressionTreeModelPortObjectSpec in project knime-core by KNIME.
Class RegressionTreePredictorCellFactory, method createFactory:
/**
 * Builds a cell factory that appends the numeric regression prediction column to the test data.
 *
 * @param predictor supplies the test data spec, the tree model spec and the predictor configuration
 * @return factory based on RegressionTreePredictor <b>predictor</b>
 * @throws InvalidSettingsException propagated from column-name generation or from mapping the model's
 *             learn columns onto the test data (presumably when columns are missing — TODO confirm)
 */
public static RegressionTreePredictorCellFactory createFactory(final RegressionTreePredictor predictor)
    throws InvalidSettingsException {
    final DataTableSpec testSpec = predictor.getDataSpec();
    final RegressionTreeModelPortObjectSpec treeSpec = predictor.getModelSpec();
    final RegressionTreePredictorConfiguration config = predictor.getConfiguration();
    final UniqueNameGenerator uniqueNames = new UniqueNameGenerator(testSpec);
    // Exactly one appended column: the double-valued prediction, named uniquely w.r.t. the test columns.
    final DataColumnSpec predictionCol = uniqueNames.newColumn(config.getPredictionColumnName(), DoubleCell.TYPE);
    final DataColumnSpec[] appendedCols = new DataColumnSpec[] {predictionCol};
    // Indices of the model's learn columns inside the incoming test table.
    final int[] learnIndices = treeSpec.calculateFilterIndices(testSpec);
    return new RegressionTreePredictorCellFactory(predictor, appendedCols, learnIndices);
}
Use of org.knime.base.node.mine.treeensemble2.model.RegressionTreeModelPortObjectSpec in project knime-core by KNIME.
Class RegressionTreeLearnerNodeModel, method execute:
/**
 * {@inheritDoc}
 * <p>
 * Restricts the input table to the learn columns, reads it into memory, learns a single regression
 * tree from it and returns that tree wrapped in a {@link RegressionTreeModelPortObject}. Warnings from
 * column filtering and data reading are combined (newline-separated) into the node warning.
 */
@Override
protected PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final BufferedDataTable inTable = (BufferedDataTable) inObjects[0];
    final DataTableSpec inSpec = inTable.getDataTableSpec();
    final FilterLearnColumnRearranger rearranger = m_configuration.filterLearnColumns(inSpec);
    String warning = rearranger.getWarning();
    final BufferedDataTable learnTable =
        exec.createColumnRearrangeTable(inTable, rearranger, exec.createSubProgress(0.0));
    final DataTableSpec learnSpec = learnTable.getDataTableSpec();

    // Progress split: 10% for reading the data into memory, 90% for learning the tree.
    final ExecutionMonitor readMon = exec.createSubProgress(0.1);
    final ExecutionMonitor learnMon = exec.createSubProgress(0.9);

    final TreeDataCreator creator = new TreeDataCreator(m_configuration, learnSpec, learnTable.getRowCount());
    exec.setProgress("Reading data into memory");
    final TreeData treeData = creator.readData(learnTable, m_configuration, readMon);
    m_hiliteRowSample = creator.getDataRowsForHilite();
    m_viewMessage = creator.getViewMessage();
    final String creatorWarning = creator.getAndClearWarningMessage();
    if (creatorWarning != null) {
        warning = (warning == null) ? creatorWarning : warning + "\n" + creatorWarning;
    }
    readMon.setProgress(1.0);

    exec.setMessage("Learning tree");
    final RandomData random = m_configuration.createRandomData();
    // Bit-vector data gets a row-count based index manager; all other tree types index the data directly.
    final IDataIndexManager indexManager = (treeData.getTreeType() == TreeType.BitVector)
        ? new BitVectorDataIndexManager(treeData.getNrRows())
        : new DefaultDataIndexManager(treeData);
    final int maxLevels = m_configuration.getMaxLevels();
    // With a bounded depth the signature factory is created with capacity 2^(maxLevels - 1);
    // NOTE(review): presumably a pre-sizing hint for the node-signature store — confirm in TreeNodeSignatureFactory.
    final TreeNodeSignatureFactory signatureFactory =
        (maxLevels < TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE)
            ? new TreeNodeSignatureFactory(IntMath.pow(2, maxLevels - 1))
            : new TreeNodeSignatureFactory();
    final RowSample sample = m_configuration.createRowSampler(treeData).createRowSample(random);
    final TreeLearnerRegression learner =
        new TreeLearnerRegression(m_configuration, treeData, indexManager, signatureFactory, random, sample);
    final TreeModelRegression tree = learner.learnSingleTree(learnMon, random);
    final RegressionTreeModel model =
        new RegressionTreeModel(m_configuration, treeData.getMetaData(), tree, treeData.getTreeType());
    final RegressionTreeModelPortObject modelPortObject =
        new RegressionTreeModelPortObject(model, new RegressionTreeModelPortObjectSpec(learnSpec));
    learnMon.setProgress(1.0);
    m_treeModelPortObject = modelPortObject;
    if (warning != null) {
        setWarningMessage(warning);
    }
    return new PortObject[] {modelPortObject};
}
Use of org.knime.base.node.mine.treeensemble2.model.RegressionTreeModelPortObjectSpec in project knime-core by KNIME.
Class RegressionTreePMMLPredictorNodeModel, method execute:
/**
 * {@inheritDoc}
 * <p>
 * Imports the regression tree from the PMML input (port 0) and appends the prediction column to the
 * data table (port 1). If the node was never configured, a default predictor configuration is derived
 * from the PMML target column.
 */
@Override
public PortObject[] execute(final PortObject[] inObjects, final ExecutionContext exec) throws Exception {
    final PMMLPortObject pmmlPO = (PMMLPortObject) inObjects[0];
    final Pair<RegressionTreeModel, RegressionTreeModelPortObjectSpec> modelSpecPair = importModel(pmmlPO);
    final BufferedDataTable data = (BufferedDataTable) inObjects[1];
    final DataTableSpec dataSpec = data.getDataTableSpec();
    // Can only happen if configure was not called before execute e.g. in generic PMML Predictor
    if (m_configuration == null) {
        m_configuration = RegressionTreePredictorConfiguration
            .createDefault(translateSpec(pmmlPO.getSpec()).getTargetColumn().getName());
    }
    final RegressionTreePredictor pred = new RegressionTreePredictor(modelSpecPair.getFirst(),
        modelSpecPair.getSecond(), dataSpec, m_configuration);
    final ColumnRearranger rearranger = pred.getPredictionRearranger();
    final BufferedDataTable outTable = exec.createColumnRearrangeTable(data, rearranger, exec);
    // Allocate a genuine PortObject[] rather than a BufferedDataTable[] viewed as PortObject[]:
    // a covariant array would throw ArrayStoreException if any other PortObject were stored into it.
    return new PortObject[] {outTable};
}
Use of org.knime.base.node.mine.treeensemble2.model.RegressionTreeModelPortObjectSpec in project knime-core by KNIME.
Class RegressionTreePMMLPredictorNodeModel, method configure:
/**
 * {@inheritDoc}
 * <p>
 * Validates that the PMML input describes a regression tree (numeric target, well-formed tree PMML),
 * initialises or refreshes the predictor configuration for that target, and returns the spec of the
 * data table with the prediction column appended.
 */
@Override
protected PortObjectSpec[] configure(final PortObjectSpec[] inSpecs) throws InvalidSettingsException {
    final PMMLPortObjectSpec pmmlSpec = (PMMLPortObjectSpec) inSpecs[0];
    final DataType targetType = extractTargetType(pmmlSpec);
    if (!targetType.isCompatible(DoubleValue.class)) {
        throw new InvalidSettingsException("This node expects a regression model.");
    }
    try {
        AbstractTreeModelPMMLTranslator.checkPMMLSpec(pmmlSpec);
    } catch (IllegalArgumentException e) {
        // Keep the original exception as cause so the details of the failed PMML check are not lost.
        throw new InvalidSettingsException(e.getMessage(), e);
    }
    final RegressionTreeModelPortObjectSpec modelSpec = translateSpec(pmmlSpec);
    final String targetColName = modelSpec.getTargetColumn().getName();
    if (m_configuration == null) {
        m_configuration = RegressionTreePredictorConfiguration.createDefault(targetColName);
    } else if (!m_configuration.isChangePredictionColumnName()) {
        // The user did not customise the column name: keep it in sync with the (possibly new) target.
        m_configuration.setPredictionColumnName(TreeEnsemblePredictorConfiguration.getPredictColumnName(targetColName));
    }
    final DataTableSpec dataSpec = (DataTableSpec) inSpecs[1];
    // No model instance is needed to compute the output spec, hence null.
    final RegressionTreePredictor pred = new RegressionTreePredictor(null, modelSpec, dataSpec, m_configuration);
    return new PortObjectSpec[] {pred.getPredictionRearranger().createSpec()};
}
Use of org.knime.base.node.mine.treeensemble2.model.RegressionTreeModelPortObjectSpec in project knime-core by KNIME.
Class RegressionTreePredictorPanel, method loadSettingsFrom:
/**
 * Loads the settings into the dialog components.
 *
 * @param settings the settings to load
 * @param specs the input port specs; index 0 must hold the regression tree model spec
 * @throws NotConfigurableException if no regression tree model spec is available at port 0
 */
public void loadSettingsFrom(final NodeSettingsRO settings, final PortObjectSpec[] specs) throws NotConfigurableException {
    // Guard against an unconnected/unconfigured model port instead of failing with NPE/ClassCastException.
    if (specs == null || specs.length == 0 || !(specs[0] instanceof RegressionTreeModelPortObjectSpec)) {
        throw new NotConfigurableException("No regression tree model available at the model input port.");
    }
    final RegressionTreeModelPortObjectSpec modelSpec = (RegressionTreeModelPortObjectSpec) specs[0];
    final String targetColName = modelSpec.getTargetColumn().getName();
    final RegressionTreePredictorConfiguration config = new RegressionTreePredictorConfiguration(targetColName);
    config.loadInDialog(settings);
    String colName = config.getPredictionColumnName();
    if (colName == null || colName.isEmpty()) {
        // Derive the default from the model's actual target column rather than an empty placeholder.
        colName = RegressionTreePredictorConfiguration.getPredictColumnName(targetColName);
    }
    m_predictionColNameField.setText(colName);
    m_changePredictionColNameCheckBox.setSelected(config.isChangePredictionColumnName());
}
Aggregations