Use of org.knime.core.data.DataCell in project knime-core by knime.
Class DecTreePredictorNodeModel, method execute.
/**
 * Loads the decision tree from the PMML model port, classifies every row of
 * the input table and returns the input table extended by the (optional)
 * class distribution columns and the prediction column.
 * <p>
 * Up to {@code m_maxNumCoveredPattern} rows are remembered in the tree as
 * covered patterns; for all further rows only the color is recorded and a
 * node warning is set afterwards.
 * <p>
 * {@inheritDoc}
 *
 * @throws RuntimeException if the PMML port contains no tree model
 * @throws IllegalStateException if no output table spec can be determined
 * @throws Exception if the tree evaluation of a row fails or yields no winner
 */
@Override
public PortObject[] execute(final PortObject[] inPorts, final ExecutionContext exec) throws CanceledExecutionException, Exception {
    exec.setMessage("Decision Tree Predictor: Loading predictor...");
    PMMLPortObject port = (PMMLPortObject) inPorts[INMODELPORT];
    List<Node> models = port.getPMMLValue().getModels(PMMLModelType.TreeModel);
    if (models.isEmpty()) {
        String msg = "Decision Tree evaluation failed: " + "No tree model found.";
        LOGGER.error(msg);
        throw new RuntimeException(msg);
    }
    PMMLDecisionTreeTranslator trans = new PMMLDecisionTreeTranslator();
    port.initializeModelTranslator(trans);
    DecisionTree decTree = trans.getDecisionTree();
    decTree.resetColorInformation();

    BufferedDataTable inData = (BufferedDataTable) inPorts[INDATAPORT];
    // the input spec does not change while iterating, so fetch it only once
    final DataTableSpec inSpec = inData.getDataTableSpec();

    // use the first column that carries a color handler (if any) as color column
    String colorColumn = null;
    for (DataColumnSpec s : inSpec) {
        if (s.getColorHandler() != null) {
            colorColumn = s.getName();
            break;
        }
    }
    decTree.setColorColumn(colorColumn);

    exec.setMessage("Decision Tree Predictor: start execution.");
    PortObjectSpec[] inSpecs = new PortObjectSpec[] { inPorts[0].getSpec(), inPorts[1].getSpec() };
    DataTableSpec outSpec = createOutTableSpec(inSpecs);
    if (outSpec == null) {
        // createOutTableSpec returns null when the prediction values are
        // unknown; fail with a clear message instead of passing null on
        throw new IllegalStateException("Decision Tree evaluation failed: output table spec could not be determined.");
    }
    BufferedDataContainer outData = exec.createDataContainer(outSpec);

    final int maxCoveredPattern = m_maxNumCoveredPattern.getIntValue();
    final boolean showDistribution = m_showDistribution.getBooleanValue();
    long coveredPattern = 0;
    long nrPattern = 0;
    long rowCount = 0;
    long numberRows = inData.size();
    exec.setMessage("Classifying...");
    for (DataRow thisRow : inData) {
        DataCell cl = null;
        LinkedHashMap<String, Double> classDistrib = null;
        try {
            Pair<DataCell, LinkedHashMap<DataCell, Double>> pair = decTree.getWinnerAndClasscounts(thisRow, inSpec);
            cl = pair.getFirst();
            LinkedHashMap<DataCell, Double> classCounts = pair.getSecond();
            classDistrib = getDistribution(classCounts);
            if (coveredPattern < maxCoveredPattern) {
                // remember this one for HiLite support
                decTree.addCoveredPattern(thisRow, inSpec);
                coveredPattern++;
            } else {
                // too many patterns for HiLite - at least remember color
                decTree.addCoveredColor(thisRow, inSpec);
            }
            nrPattern++;
        } catch (Exception e) {
            // pass the exception along so the log keeps the stack trace
            LOGGER.error("Decision Tree evaluation failed: " + e.getMessage(), e);
            throw e;
        }
        if (cl == null) {
            LOGGER.error("Decision Tree evaluation failed: result empty");
            throw new Exception("Decision Tree evaluation failed.");
        }

        // output row layout: input cells, distribution cells (optional), prediction
        DataCell[] newCells = new DataCell[outSpec.getNumColumns()];
        int numInCells = thisRow.getNumCells();
        for (int i = 0; i < numInCells; i++) {
            newCells[i] = thisRow.getCell(i);
        }
        if (showDistribution) {
            for (int i = numInCells; i < newCells.length - 1; i++) {
                // the distribution is looked up by the name of the output column
                String predClass = outSpec.getColumnSpec(i).getName();
                Double probability = classDistrib == null ? null : classDistrib.get(predClass);
                // classes without an entry in the distribution get probability 0
                newCells[i] = new DoubleCell(probability == null ? 0.0 : probability.doubleValue());
            }
        }
        newCells[newCells.length - 1] = cl;
        outData.addRowToTable(new DefaultRow(thisRow.getKey(), newCells));

        rowCount++;
        if (rowCount % 100 == 0) {
            exec.setProgress(rowCount / (double) numberRows, "Classifying... Row " + rowCount + " of " + numberRows);
        }
        exec.checkCanceled();
    }
    if (coveredPattern < nrPattern) {
        // let the user know that we did not store all available pattern
        // for HiLiting.
        this.setWarningMessage("Tree only stored first " + m_maxNumCoveredPattern.getIntValue() + " (of " + nrPattern + ") rows for HiLiting!");
    }
    outData.close();
    m_decTree = decTree;
    exec.setMessage("Decision Tree Predictor: end execution.");
    return new BufferedDataTable[] { outData.getTable() };
}
Use of org.knime.core.data.DataCell in project knime-core by knime.
Class DecTreePredictorNodeModel, method createOutTableSpec.
/**
 * Creates the spec of the output table: all columns of the input data spec,
 * followed by one double column per prediction value (only if the
 * distribution is to be shown) and finally the prediction column.
 *
 * @param inSpecs the port specs, indexed by {@code INMODELPORT} and
 *            {@code INDATAPORT}
 * @return the output table spec, or {@code null} if the distribution is
 *         requested but the prediction values cannot be determined
 */
private DataTableSpec createOutTableSpec(final PortObjectSpec[] inSpecs) {
    LinkedList<DataCell> predValues = null;
    if (m_showDistribution.getBooleanValue()) {
        predValues = getPredictionValues((PMMLPortObjectSpec) inSpecs[INMODELPORT]);
        if (predValues == null) {
            // no out spec can be determined
            return null;
        }
    }
    int numDistribCols = predValues == null ? 0 : predValues.size();
    DataTableSpec inSpec = (DataTableSpec) inSpecs[INDATAPORT];
    // generated column names must not clash with names of the input table
    UniqueNameGenerator nameGenerator = new UniqueNameGenerator(inSpec);
    DataColumnSpec[] newCols = new DataColumnSpec[numDistribCols + 1];

    // add all distribution columns; each gets the domain [0,1]
    // (no preferred renderer is set on these columns)
    if (predValues != null) {
        DataColumnDomain domain = new DataColumnDomainCreator(new DoubleCell(0.0), new DoubleCell(1.0)).createDomain();
        int col = 0;
        // iterate instead of calling get(i): indexed access on a LinkedList
        // walks the list on every call
        for (DataCell predValue : predValues) {
            DataColumnSpecCreator colSpecCreator = nameGenerator.newCreator(predValue.toString(), DoubleCell.TYPE);
            colSpecCreator.setDomain(domain);
            newCols[col++] = colSpecCreator.createSpec();
        }
    }

    // add the prediction column
    newCols[numDistribCols] = nameGenerator.newColumn("Prediction (DecTree)", StringCell.TYPE);
    DataTableSpec newColSpec = new DataTableSpec(newCols);
    return new DataTableSpec(inSpec, newColSpec);
}
Use of org.knime.core.data.DataCell in project knime-core by knime.
Class LinearRegressionContent, method predict.
/**
 * Computes the linear prediction for one row: the offset plus the sum of
 * every cell value weighted by the multiplier at the same position.
 *
 * @param row the data row to predict; its cells are read by position
 * @return a {@link DoubleCell} holding the predicted value, or a missing
 *         cell as soon as any cell of the row is missing
 */
public DataCell predict(final DataRow row) {
    final int cellCount = row.getNumCells();
    double prediction = m_offset;
    for (int col = 0; col < cellCount; col++) {
        final DataCell cell = row.getCell(col);
        if (cell.isMissing()) {
            // a single missing input makes the whole prediction missing
            return DataType.getMissingCell();
        }
        // NOTE(review): the cast requires a concrete DoubleCell; other
        // numeric cell types would fail here - confirm callers guarantee this
        final DoubleCell numeric = (DoubleCell) cell;
        prediction += m_multipliers[col] * numeric.getDoubleValue();
    }
    return new DoubleCell(prediction);
}
Use of org.knime.core.data.DataCell in project knime-core by knime.
Class LogRegPredictor, method determineTargetCategories.
/**
 * Retrieves the target values from the PMML model.
 * <p>
 * The result contains the target categories of the parameter matrix in the
 * order they occur there, followed by the target reference category. If the
 * model does not name a reference category, the last value of the sorted
 * column domain is used, provided the domain has exactly one value more
 * than the parameter matrix has categories.
 *
 * @param targetCol the spec of the target column; its domain supplies the
 *            cells that the category names of the model are mapped to
 * @param content the regression content read from the PMML model
 * @return the target categories as data cells, reference category last
 * @throws InvalidSettingsException if the target column has no possible
 *             values in its domain, or if the PMML model is inconsistent or
 *             ambiguous
 */
private static List<DataCell> determineTargetCategories(final DataColumnSpec targetCol, final PMMLGeneralRegressionContent content) throws InvalidSettingsException {
    Set<DataCell> possibleValues = targetCol.getDomain().getValues();
    if (possibleValues == null) {
        // without domain values the categories cannot be mapped to cells
        throw new InvalidSettingsException("The target column \"" + targetCol.getName() + "\" has no possible values in its domain.");
    }
    // lookup from the string representation of a domain value to its cell
    Map<String, DataCell> domainValues = new HashMap<String, DataCell>();
    for (DataCell cell : possibleValues) {
        domainValues.put(cell.toString(), cell);
    }
    // Collect target categories from model
    // NOTE(review): a category that is not part of the column domain maps to
    // null here and is added as such - confirm whether this should be rejected
    Set<DataCell> modelTargetCategories = new LinkedHashSet<DataCell>();
    for (PMMLPCell cell : content.getParamMatrix()) {
        modelTargetCategories.add(domainValues.get(cell.getTargetCategory()));
    }
    String targetReferenceCategory = content.getTargetReferenceCategory();
    if (targetReferenceCategory == null || targetReferenceCategory.isEmpty()) {
        List<DataCell> targetCategories = new ArrayList<DataCell>(possibleValues);
        Collections.sort(targetCategories, targetCol.getType().getComparator());
        if (targetCategories.size() == modelTargetCategories.size() + 1) {
            // the last target category is the target reference category
            targetReferenceCategory = targetCategories.get(targetCategories.size() - 1).toString();
            LOGGER.debug("The target reference category is not explicitly set in PMML. Automatically choose : " + targetReferenceCategory);
        } else {
            throw new InvalidSettingsException("Please set the attribute \"targetReferenceCategory\" of the \"GeneralRegression\" element in the PMML file.");
        }
    }
    modelTargetCategories.add(domainValues.get(targetReferenceCategory));
    return new ArrayList<DataCell>(modelTargetCategories);
}
Use of org.knime.core.data.DataCell in project knime-core by knime.
Class LogRegPredictor, method getCells.
/**
* Computes the output cells for one input row: optionally one probability
* cell per target category, and the predicted target category as last cell.
* <p>
* If the row has missing values, the result of createMissingOutput() is
* returned instead.
* <p>
* {@inheritDoc}
*/
@Override
public DataCell[] getCells(final DataRow row) {
if (hasMissingValues(row)) {
return createMissingOutput();
}
final MissingHandling missingHandling = new MissingHandling(true);
// one cell per target domain value plus the prediction, or the prediction only
DataCell[] cells = m_includeProbs ? new DataCell[1 + m_targetDomainValuesCount] : new DataCell[1];
// every slot initially refers to the same IntCell(0) instance
Arrays.fill(cells, new IntCell(0));
// row vector (1 x number of parameters) holding the design values of this row
final RealMatrix x = MatrixUtils.createRealMatrix(1, m_parameters.size());
for (int i = 0; i < m_parameters.size(); i++) {
String parameter = m_parameters.get(i);
String predictor = null;
String value = null;
boolean rowIsEmpty = true;
// find the first predictor that has an entry for this parameter in the
// PP matrix; predictor and value keep that match after the loop
for (final Iterator<String> iter = m_predictors.iterator(); iter.hasNext(); ) {
predictor = iter.next();
value = m_ppMatrix.getValue(parameter, predictor, null);
if (null != value) {
rowIsEmpty = false;
break;
}
}
if (rowIsEmpty) {
// no predictor is associated with this parameter: constant 1
// (presumably the intercept term - TODO confirm)
x.setEntry(0, i, 1);
} else {
if (m_factors.contains(predictor)) {
// categorical predictor: position of the row's value in the value list
List<DataCell> values = m_values.get(predictor);
DataCell cell = row.getCell(m_parameterI.get(parameter));
int index = values.indexOf(cell);
/* When building a general regression model, one category of each
categorical field is used as the default baseline and therefore
does not show up in the ParameterList in PMML. This design is
fine for the training, but in the prediction, when the input
(e.g. of Employment) is the default baseline, the parameters
should all be 0.
See the commit message for an example and more details.
*/
// index <= 0 (baseline or value not found): nothing is set and i only
// advances by the loop increment; otherwise the entry of the matching
// category is set to 1 and i is moved forward by values.size() - 2
if (index > 0) {
x.setEntry(0, i + index - 1, 1);
i += values.size() - 2;
}
} else if (m_baseLabelToColName.containsKey(parameter) && m_vectorLengths.containsKey(m_baseLabelToColName.get(parameter))) {
// parameter belongs to a vector column: the predictor name is parsed
// for the element index; if parsing yields nothing the entry is left
// at its initial value
final DataCell cell = row.getCell(m_parameterI.get(parameter));
Optional<NameAndIndex> vectorValue = VectorHandling.parse(predictor);
if (vectorValue.isPresent()) {
int j = vectorValue.get().getIndex();
value = m_ppMatrix.getValue(parameter, predictor, null);
// the PP matrix value is parsed as an integer and used as exponent
double exponent = Integer.valueOf(value);
double radix = RegressionTrainingRow.getValue(cell, j, missingHandling);
x.setEntry(0, i, Math.pow(radix, exponent));
}
} else {
// plain numeric predictor: cell value raised to the integer exponent
// taken from the PP matrix
DataCell cell = row.getCell(m_parameterI.get(parameter));
double radix = ((DoubleValue) cell).getDoubleValue();
double exponent = Integer.valueOf(value);
x.setEntry(0, i, Math.pow(radix, exponent));
}
}
}
// row vector: product of the design values with m_beta, one entry per column
RealMatrix r = x.multiply(m_beta);
// determine the column with highest probability (first one wins on ties)
int maxIndex = 0;
double maxValue = r.getEntry(0, 0);
for (int i = 1; i < r.getColumnDimension(); i++) {
if (r.getEntry(0, i) > maxValue) {
maxValue = r.getEntry(0, i);
maxIndex = i;
}
}
if (m_includeProbs) {
// compute probabilities of the target categories
for (int i = 0; i < m_targetCategories.size(); i++) {
// test if calculation would overflow: any difference above 700 is treated
// as overflow (presumably because Math.exp becomes infinite slightly
// above 709 - TODO confirm)
boolean overflow = false;
for (int k = 0; k < r.getColumnDimension(); k++) {
if ((r.getEntry(0, k) - r.getEntry(0, i)) > 700) {
overflow = true;
}
}
if (!overflow) {
// probability of category i is 1 / sum_k exp(r_k - r_i)
double sum = 0;
for (int k = 0; k < r.getColumnDimension(); k++) {
sum += Math.exp(r.getEntry(0, k) - r.getEntry(0, i));
}
cells[m_targetCategoryIndex.get(i)] = new DoubleCell(1.0 / sum);
} else {
// on overflow the probability is reported as 0
cells[m_targetCategoryIndex.get(i)] = new DoubleCell(0);
}
}
}
// the last cell is the prediction
cells[cells.length - 1] = m_targetCategories.get(maxIndex);
return cells;
}
Aggregations