use of org.knime.base.node.mine.treeensemble2.data.PredictorRecord in project knime-core by knime.
the class LKGradientBoostingPredictorCellFactory method getCells.
/**
 * {@inheritDoc}
 * <p>
 * For every class the model's initial value and the coefficients of the matching nodes of all levels are summed
 * up. The sums are turned into class probabilities with a softmax and the label of the most probable class is
 * emitted, optionally followed by its probability and the probabilities of all classes.
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final DataRow filterRow = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    final int nrClasses = m_model.getNrClasses();
    final int nrLevels = m_model.getNrLevels();
    // NOTE(review): other predictors in this package treat a null record as "missing input";
    // confirm whether createPredictorRecord can return null for this model type.
    final PredictorRecord record = m_model.createPredictorRecord(filterRow, m_learnSpec);

    // additive score per class: initial value + coefficient of the matching node in every level
    final double[] classFunctionPredictions = new double[nrClasses];
    Arrays.fill(classFunctionPredictions, m_model.getInitialValue());
    for (int i = 0; i < nrLevels; i++) {
        for (int j = 0; j < nrClasses; j++) {
            final TreeNodeRegression matchingNode = m_model.getModel(i, j).findMatchingNode(record);
            classFunctionPredictions[j] += m_model.getCoefficientMap(i, j).get(matchingNode.getSignature());
        }
    }

    // Softmax with the maximum score subtracted before exponentiation. The shift cancels out in the
    // normalization, but it keeps Math.exp from overflowing to infinity (or all terms from underflowing to 0),
    // which would turn every probability into NaN and leave no class selected.
    double maxPrediction = Double.NEGATIVE_INFINITY;
    for (final double prediction : classFunctionPredictions) {
        if (prediction > maxPrediction) {
            maxPrediction = prediction;
        }
    }
    // an infinite maximum cannot be subtracted meaningfully (inf - inf is NaN), so no shift is applied then
    final double shift = Double.isInfinite(maxPrediction) ? 0.0 : maxPrediction;
    final double[] classProbabilities = new double[nrClasses];
    double expSum = 0;
    for (int i = 0; i < nrClasses; i++) {
        classProbabilities[i] = Math.exp(classFunctionPredictions[i] - shift);
        expSum += classProbabilities[i];
    }

    // normalize and pick the class with the highest probability (first one wins on ties)
    int classIdx = -1;
    double classProb = -1;
    for (int i = 0; i < nrClasses; i++) {
        classProbabilities[i] /= expSum;
        if (classProbabilities[i] > classProb) {
            classIdx = i;
            classProb = classProbabilities[i];
        }
    }

    final ArrayList<DataCell> cells = new ArrayList<DataCell>();
    cells.add(new StringCell(m_model.getClassLabel(classIdx)));
    if (m_config.isAppendPredictionConfidence()) {
        cells.add(new DoubleCell(classProb));
    }
    if (m_config.isAppendClassConfidences()) {
        // the map is necessary to ensure that the probabilities are correctly associated with the column header
        final Map<String, Double> classProbMap = new HashMap<String, Double>((int) (nrClasses * 1.5));
        for (int i = 0; i < nrClasses; i++) {
            classProbMap.put(m_model.getClassLabel(i), classProbabilities[i]);
        }
        // output order follows the key order of m_targetValueMap
        for (final String className : m_targetValueMap.keySet()) {
            cells.add(new DoubleCell(classProbMap.get(className)));
        }
    }
    return cells.toArray(new DataCell[cells.size()]);
}
use of org.knime.base.node.mine.treeensemble2.data.PredictorRecord in project knime-core by knime.
the class TreeEnsembleRegressionPredictorCellFactory method getCells.
/**
 * {@inheritDoc}
 * <p>
 * Produces the mean of the matching nodes' means over all considered trees, optionally followed by the variance
 * of those means and by the number of trees that contributed. Trees for which the row was part of the training
 * data are skipped when an out-of-bag filter is present.
 */
@Override
public DataCell[] getCells(final DataRow row) {
    final TreeEnsembleModelPortObject portObject = m_predictor.getModelObject();
    final TreeEnsemblePredictorConfiguration configuration = m_predictor.getConfiguration();
    final TreeEnsembleModel ensemble = portObject.getEnsembleModel();
    final boolean withConfidence = configuration.isAppendPredictionConfidence();
    final boolean withModelCount = configuration.isAppendModelCount();
    // one cell for the prediction plus one for each optional output
    final int nrCells = 1 + (withConfidence ? 1 : 0) + (withModelCount ? 1 : 0);
    final boolean outOfBag = m_predictor.hasOutOfBagFilter();
    final DataCell[] cells = new DataCell[nrCells];
    final DataRow learnRow = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    final PredictorRecord record = ensemble.createPredictorRecord(learnRow, m_learnSpec);
    if (record == null) {
        // no record could be created (missing value): every output cell is missing
        Arrays.fill(cells, DataType.getMissingCell());
        return cells;
    }

    final Mean meanStat = new Mean();
    final Variance varianceStat = new Variance();
    final int treeCount = ensemble.getNrModels();
    for (int t = 0; t < treeCount; t++) {
        if (outOfBag && m_predictor.isRowPartOfTrainingData(row.getKey(), t)) {
            // the row was used to train this tree, so the tree must not vote
            continue;
        }
        final double matchedMean = ensemble.getTreeModelRegression(t).findMatchingNode(record).getMean();
        meanStat.increment(matchedMean);
        varianceStat.increment(matchedMean);
    }

    final int votes = (int) meanStat.getN();
    final boolean noVotes = votes == 0;
    int pos = 0;
    if (noVotes) {
        cells[pos++] = DataType.getMissingCell();
    } else {
        cells[pos++] = new DoubleCell(meanStat.getResult());
    }
    if (withConfidence) {
        if (noVotes) {
            cells[pos++] = DataType.getMissingCell();
        } else {
            cells[pos++] = new DoubleCell(varianceStat.getResult());
        }
    }
    if (withModelCount) {
        cells[pos++] = new IntCell(votes);
    }
    return cells;
}
use of org.knime.base.node.mine.treeensemble2.data.PredictorRecord in project knime-core by knime.
the class TreeNodeNumericConditionTest method testTestCondition.
/**
 * This method tests the
 * {@link TreeNodeNumericCondition#testCondition(org.knime.base.node.mine.treeensemble2.data.PredictorRecord)}
 * method for both numeric operators, each with and without acceptance of missing values.
 *
 * @throws Exception
 */
@Test
public void testTestCondition() throws Exception {
    final TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(false);
    final TestDataGenerator dataGen = new TestDataGenerator(config);
    final TreeNumericColumnData col = dataGen.createNumericAttributeColumn("1,2,3,4,4,5,6,7", "testCol", 0);
    final Map<String, Object> map = Maps.newHashMap();
    final String colName = col.getMetaData().getAttributeName();

    // "<= 3", missing values are rejected
    TreeNodeNumericCondition cond = new TreeNodeNumericCondition(col.getMetaData(), 3, NumericOperator.LessThanOrEqual, false);
    map.put(colName, 2.5);
    // The record is created once and the map is modified afterwards; the test relies on the record
    // reflecting those later changes (presumably it keeps a reference to the map).
    final PredictorRecord record = new PredictorRecord(map);
    assertTrue("2.5 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertTrue("3 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 4);
    assertFalse("4 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertFalse("Missing values were falsely accepted.", cond.testCondition(record));

    // "<= 3", missing values are accepted
    cond = new TreeNodeNumericCondition(col.getMetaData(), 3, NumericOperator.LessThanOrEqual, true);
    map.clear();
    map.put(colName, 2.5);
    assertTrue("2.5 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertTrue("3 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, 4);
    assertFalse("4 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertTrue("Missing values were falsely rejected.", cond.testCondition(record));

    // "> 4", missing values are rejected
    cond = new TreeNodeNumericCondition(col.getMetaData(), 4, NumericOperator.LargerThan, false);
    map.clear();
    map.put(colName, 2.5);
    assertFalse("2.5 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertFalse("3 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 4);
    assertFalse("4 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 4.01);
    assertTrue("4.01 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertFalse("Missing values were falsely accepted.", cond.testCondition(record));

    // "> 4", missing values are accepted
    cond = new TreeNodeNumericCondition(col.getMetaData(), 4, NumericOperator.LargerThan, true);
    map.clear();
    map.put(colName, 2.5);
    assertFalse("2.5 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 3);
    assertFalse("3 was falsely accepted.", cond.testCondition(record));
    map.clear();
    // boundary value: 4 is not larger than 4, same expectation as in the section without missing values
    map.put(colName, 4);
    assertFalse("4 was falsely accepted.", cond.testCondition(record));
    map.clear();
    map.put(colName, 4.01);
    assertTrue("4.01 was falsely rejected.", cond.testCondition(record));
    map.clear();
    map.put(colName, PredictorRecord.NULL);
    assertTrue("Missing values were falsely rejected.", cond.testCondition(record));
}
use of org.knime.base.node.mine.treeensemble2.data.PredictorRecord in project knime-core by knime.
the class AbstractGradientBoostingLearner method createPredictorRecord.
/**
 * Creates a PredictorRecord from the inMemory TreeData object.
 * <p>
 * For every attribute column of <b>data</b> the value belonging to the row is looked up via the index manager,
 * passed through {@code handleMissingValues} and stored under the attribute's name.
 *
 * @param data the tree data whose attribute columns are read
 * @param indexManager provides, per attribute index, the positions in the column that are indexed by row
 * @param rowIdx index of the row for which the record is created
 * @return a PredictorRecord for the row at <b>rowIdx</b> in <b>data</b>
 */
public static PredictorRecord createPredictorRecord(final TreeData data, final IDataIndexManager indexManager, final int rowIdx) {
    final Map<String, Object> valuesByAttribute = new HashMap<String, Object>();
    for (final TreeAttributeColumnData attributeColumn : data.getColumns()) {
        final TreeAttributeColumnMetaData columnMeta = attributeColumn.getMetaData();
        final String attributeName = columnMeta.getAttributeName();
        final int attributeIndex = columnMeta.getAttributeIndex();
        // translate the row index into the position of the row's value inside this column
        final Object value = handleMissingValues(
            attributeColumn.getValueAt(indexManager.getPositionsInColumn(attributeIndex)[rowIdx]), attributeColumn);
        valuesByAttribute.put(attributeName, value);
    }
    return new PredictorRecord(valuesByAttribute);
}
use of org.knime.base.node.mine.treeensemble2.data.PredictorRecord in project knime-core by knime.
the class AbstractTreeEnsembleModel method createBitVectorPredictorRecord.
/**
 * Creates a record from a row that consists of a single bit vector cell. Every bit is stored as a Boolean under
 * the attribute name that {@code TreeBitColumnMetaData} assigns to its index.
 *
 * @param filterRow row expected to contain exactly one (bit vector) cell
 * @return the record, or <code>null</code> if the cell is missing
 * @throws IllegalArgumentException if the bit vector length differs from the number of attributes in the meta
 *             data
 */
private PredictorRecord createBitVectorPredictorRecord(final DataRow filterRow) {
    assert filterRow.getNumCells() == 1 : "Expected one cell as bit vector data";
    final DataCell cell = filterRow.getCell(0);
    if (cell.isMissing()) {
        return null;
    }
    final BitVectorValue bits = (BitVectorValue) cell;
    final long length = bits.length();
    final long expectedLength = getMetaData().getNrAttributes();
    if (length != expectedLength) {
        final String rowKey = filterRow.getKey().getString();
        throw new IllegalArgumentException("The bit-vector in " + rowKey + " has the wrong length. (" + length
            + " instead of " + getMetaData().getNrAttributes() + ")");
    }
    // initial capacity chosen so that the map does not need to grow at the default load factor
    final int capacity = (int) (length / 0.75 + 1.0);
    final Map<String, Object> bitsByAttribute = new LinkedHashMap<String, Object>(capacity);
    for (int bitIdx = 0; bitIdx < length; bitIdx++) {
        final boolean isSet = bits.get(bitIdx);
        bitsByAttribute.put(TreeBitColumnMetaData.getAttributeName(bitIdx), Boolean.valueOf(isSet));
    }
    return new PredictorRecord(bitsByAttribute);
}
Aggregations