Usage example of org.knime.base.data.filter.column.FilterColumnRow in the knime-core project: class BasisFunctionPredictorCellFactory, method getCells.
/**
 * Predicts the given row using the underlying basis function model.
 * {@inheritDoc}
 */
@Override
public DataCell[] getCells(final DataRow row) {
    // restrict the row to the columns the model was trained on
    final DataRow filteredRow = new FilterColumnRow(row, m_filteredColumns);
    final DataCell[] prediction = predict(filteredRow, m_model);
    if (!m_appendClassProps) {
        // only the predicted class label (last cell), no class probabilities
        return new DataCell[]{prediction[prediction.length - 1]};
    }
    // full prediction: class probabilities plus winner label
    return prediction;
}
Usage example of org.knime.base.data.filter.column.FilterColumnRow in the knime-core project: class LKGradientBoostingPredictorCellFactory, method getCells.
/**
 * Predicts the class of the given row with the gradient-boosted ensemble:
 * accumulates the per-class additive tree predictions, converts them to
 * probabilities via softmax, and emits the winning label plus optional
 * confidence cells.
 * {@inheritDoc}
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final DataRow filterRow = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    final int nrClasses = m_model.getNrClasses();
    final int nrLevels = m_model.getNrLevels();
    final PredictorRecord record = m_model.createPredictorRecord(filterRow, m_learnSpec);
    // additive model: start from the initial value, add each level's tree output
    final double[] classFunctionPredictions = new double[nrClasses];
    Arrays.fill(classFunctionPredictions, m_model.getInitialValue());
    for (int i = 0; i < nrLevels; i++) {
        for (int j = 0; j < nrClasses; j++) {
            final TreeNodeRegression matchingNode = m_model.getModel(i, j).findMatchingNode(record);
            classFunctionPredictions[j] += m_model.getCoefficientMap(i, j).get(matchingNode.getSignature());
        }
    }
    // Softmax with max-subtraction: exp(x) overflows to +Infinity for large x,
    // which previously produced Infinity/Infinity = NaN probabilities.
    // Subtracting the maximum is mathematically equivalent and numerically safe.
    double maxPrediction = Double.NEGATIVE_INFINITY;
    for (int i = 0; i < nrClasses; i++) {
        maxPrediction = Math.max(maxPrediction, classFunctionPredictions[i]);
    }
    final double[] classProbabilities = new double[nrClasses];
    double expSum = 0;
    for (int i = 0; i < nrClasses; i++) {
        classProbabilities[i] = Math.exp(classFunctionPredictions[i] - maxPrediction);
        expSum += classProbabilities[i];
    }
    // normalize and track the most probable class
    int classIdx = -1;
    double classProb = -1;
    for (int i = 0; i < nrClasses; i++) {
        classProbabilities[i] /= expSum;
        if (classProbabilities[i] > classProb) {
            classIdx = i;
            classProb = classProbabilities[i];
        }
    }
    // at most: label + overall confidence + one confidence per class
    final ArrayList<DataCell> cells = new ArrayList<DataCell>(2 + nrClasses);
    cells.add(new StringCell(m_model.getClassLabel(classIdx)));
    if (m_config.isAppendPredictionConfidence()) {
        cells.add(new DoubleCell(classProb));
    }
    if (m_config.isAppendClassConfidences()) {
        // the map is necessary to ensure that the probabilities are correctly associated with the column header
        final Map<String, Double> classProbMap = new HashMap<String, Double>((int) (nrClasses * 1.5));
        for (int i = 0; i < nrClasses; i++) {
            classProbMap.put(m_model.getClassLabel(i), classProbabilities[i]);
        }
        // iterate the target value map so cell order matches the output spec
        for (final String className : m_targetValueMap.keySet()) {
            cells.add(new DoubleCell(classProbMap.get(className)));
        }
    }
    return cells.toArray(new DataCell[cells.size()]);
}
Usage example of org.knime.base.data.filter.column.FilterColumnRow in the knime-core project: class TreeEnsembleRegressionPredictorCellFactory, method getCells.
/**
 * Computes the ensemble regression prediction for the given row: the mean of
 * the per-tree leaf means, optionally with the variance (as confidence) and
 * the number of contributing models.
 * {@inheritDoc}
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final TreeEnsembleModelPortObject modelObject = m_predictor.getModelObject();
    final TreeEnsemblePredictorConfiguration cfg = m_predictor.getConfiguration();
    final TreeEnsembleModel ensembleModel = modelObject.getEnsembleModel();
    final boolean appendConfidence = cfg.isAppendPredictionConfidence();
    final boolean appendModelCount = cfg.isAppendModelCount();
    // one prediction cell, plus one cell per enabled extra output
    final int cellCount = 1 + (appendConfidence ? 1 : 0) + (appendModelCount ? 1 : 0);
    final DataCell[] cells = new DataCell[cellCount];
    final DataRow filterRow = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    final PredictorRecord record = ensembleModel.createPredictorRecord(filterRow, m_learnSpec);
    if (record == null) {
        // row contains missing values -- no prediction possible
        Arrays.fill(cells, DataType.getMissingCell());
        return cells;
    }
    final boolean hasOutOfBagFilter = m_predictor.hasOutOfBagFilter();
    final Mean meanAggregator = new Mean();
    final Variance varianceAggregator = new Variance();
    final int nrModels = ensembleModel.getNrModels();
    for (int modelIdx = 0; modelIdx < nrModels; modelIdx++) {
        if (hasOutOfBagFilter && m_predictor.isRowPartOfTrainingData(row.getKey(), modelIdx)) {
            // skip models that saw this row during training (out-of-bag estimate)
            continue;
        }
        final TreeModelRegression tree = ensembleModel.getTreeModelRegression(modelIdx);
        final double leafMean = tree.findMatchingNode(record).getMean();
        meanAggregator.increment(leafMean);
        varianceAggregator.increment(leafMean);
    }
    final int nrValidModels = (int)meanAggregator.getN();
    final boolean noValidModel = nrValidModels == 0;
    int cellIdx = 0;
    cells[cellIdx++] = noValidModel ? DataType.getMissingCell() : new DoubleCell(meanAggregator.getResult());
    if (appendConfidence) {
        cells[cellIdx++] = noValidModel ? DataType.getMissingCell() : new DoubleCell(varianceAggregator.getResult());
    }
    if (appendModelCount) {
        cells[cellIdx++] = new IntCell(nrValidModels);
    }
    return cells;
}
Usage example of org.knime.base.data.filter.column.FilterColumnRow in the knime-core project: class RegressionTreePredictorCellFactory, method getCells.
/**
 * Predicts the target value for the given row as the mean stored in the
 * matching leaf of the regression tree; a missing cell is returned when the
 * row cannot be converted into a predictor record.
 * {@inheritDoc}
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final RegressionTreeModel treeModel = m_predictor.getModel();
    final DataRow filteredRow = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    final PredictorRecord record = treeModel.createPredictorRecord(filteredRow, m_learnSpec);
    if (record == null) {
        // row contains missing values -- no prediction possible
        return new DataCell[]{DataType.getMissingCell()};
    }
    final TreeModelRegression tree = treeModel.getTreeModel();
    final TreeNodeRegression leaf = tree.findMatchingNode(record);
    return new DataCell[]{new DoubleCell(leaf.getMean())};
}
Usage example of org.knime.base.data.filter.column.FilterColumnRow in the knime-core project: class UnpivotNodeModel, method execute.
/**
 * Unpivots the input table: every selected value column of every input row
 * becomes one output row (row id, column name, value), with the retained
 * order columns appended. Optionally skips missing values and builds a
 * hilite mapping from input rows to the generated output rows.
 * {@inheritDoc}
 */
@Override
protected BufferedDataTable[] execute(final BufferedDataTable[] inData, final ExecutionContext exec) throws Exception {
    DataTableSpec inSpec = inData[0].getSpec();
    List<String> orderColumns = m_orderColumns.getIncludeList();
    List<String> valueColumns = m_valueColumns.getIncludeList();
    int[] orderColumnIdx = new int[orderColumns.size()];
    for (int i = 0; i < orderColumnIdx.length; i++) {
        orderColumnIdx[i] = inSpec.findColumnIndex(orderColumns.get(i));
    }
    // hoist value-column lookups out of the row loop; previously
    // findColumnIndex was called once per row and value column (O(rows * cols) spec scans)
    int[] valueColumnIdx = new int[valueColumns.size()];
    for (int i = 0; i < valueColumnIdx.length; i++) {
        valueColumnIdx[i] = inSpec.findColumnIndex(valueColumns.get(i));
    }
    final double newRowCnt = inData[0].getRowCount() * valueColumns.size();
    final boolean enableHilite = m_enableHilite.getBooleanValue();
    final boolean skipMissings = m_missingValues.getBooleanValue();
    LinkedHashMap<RowKey, Set<RowKey>> map = new LinkedHashMap<RowKey, Set<RowKey>>();
    DataTableSpec outSpec = createOutSpec(inSpec);
    BufferedDataContainer buf = exec.createDataContainer(outSpec);
    for (DataRow row : inData[0]) {
        LinkedHashSet<RowKey> set = new LinkedHashSet<RowKey>();
        // order columns are appended to every generated row
        FilterColumnRow crow = new FilterColumnRow(row, orderColumnIdx);
        for (int i = 0; i < valueColumns.size(); i++) {
            String colName = valueColumns.get(i);
            DataCell acell = row.getCell(valueColumnIdx[i]);
            if (acell.isMissing() && skipMissings) {
                // skip rows containing missing cells (in Value column(s))
                continue;
            }
            RowKey rowKey = RowKey.createRowKey(buf.size());
            if (enableHilite) {
                set.add(rowKey);
            }
            DefaultRow drow = new DefaultRow(rowKey, new StringCell(row.getKey().getString()), new StringCell(colName), acell);
            buf.addRowToTable(new AppendedColumnRow(rowKey, drow, crow));
            exec.checkCanceled();
            exec.setProgress(buf.size() / newRowCnt);
        }
        if (enableHilite) {
            // map the input row key to the set of output rows generated from it
            map.put(crow.getKey(), set);
        }
    }
    buf.close();
    if (enableHilite) {
        m_trans.setMapper(new DefaultHiLiteMapper(map));
    } else {
        m_trans.setMapper(null);
    }
    return new BufferedDataTable[] { buf.getTable() };
}
Aggregations