Use of org.knime.core.data.def.StringCell in project knime-core by KNIME:
class TreeEnsembleClassificationPredictorCellFactory2, method getCells.
/**
 * Predicts the class of a single row with the tree ensemble.
 * <p>
 * Every tree that may vote for the row contributes the normalized class distribution of the leaf the row falls
 * into. That is all trees, or, if an out-of-bag filter is active, only the trees that did not see the row during
 * training. The predicted class is the one with the largest summed probability.
 * <p>
 * The returned cells are, in order:
 * <ol>
 * <li>the predicted class,</li>
 * <li>optionally, its confidence (averaged over the voting trees, in [0, 1]),</li>
 * <li>optionally, one averaged confidence per target class,</li>
 * <li>optionally, the number of trees that voted.</li>
 * </ol>
 * If the row cannot be converted into a predictor record, all cells are missing. If no tree may vote for the row,
 * all prediction cells are missing and the appended model count (if any) is 0.
 *
 * @param row the input row
 * @return the prediction cells for {@code row}
 */
@Override
public DataCell[] getCells(final DataRow row) {
    TreeEnsembleModelPortObject modelObject = m_predictor.getModelObject();
    TreeEnsemblePredictorConfiguration cfg = m_predictor.getConfiguration();
    final TreeEnsembleModel ensembleModel = modelObject.getEnsembleModel();
    final boolean appendConfidence = cfg.isAppendPredictionConfidence();
    final boolean appendClassConfidences = cfg.isAppendClassConfidences();
    final boolean appendModelCount = cfg.isAppendModelCount();
    int size = 1;
    if (appendConfidence) {
        size += 1;
    }
    if (appendClassConfidences) {
        size += m_targetValueMap.size();
    }
    if (appendModelCount) {
        size += 1;
    }
    final boolean hasOutOfBagFilter = m_predictor.hasOutOfBagFilter();
    DataCell[] result = new DataCell[size];
    DataRow filterRow = new FilterColumnRow(row, m_learnColumnInRealDataIndices);
    PredictorRecord record = ensembleModel.createPredictorRecord(filterRow, m_learnSpec);
    if (record == null) {
        // missing value in a learn column: no prediction possible
        Arrays.fill(result, DataType.getMissingCell());
        return result;
    }
    final int nrModels = ensembleModel.getNrModels();
    TreeTargetNominalColumnMetaData targetMeta =
        (TreeTargetNominalColumnMetaData) ensembleModel.getMetaData().getTargetMetaData();
    final NominalValueRepresentation[] targetValues = targetMeta.getValues();
    // sum over voting trees of each tree's normalized leaf class distribution
    final double[] classProbabilities = new double[targetValues.length];
    int nrValidModels = 0;
    for (int i = 0; i < nrModels; i++) {
        if (hasOutOfBagFilter && m_predictor.isRowPartOfTrainingData(row.getKey(), i)) {
            // row was used to train this tree; it must not vote (out-of-bag prediction)
            continue;
        }
        TreeModelClassification m = ensembleModel.getTreeModelClassification(i);
        TreeNodeClassification match = m.findMatchingNode(record);
        final float[] nodeClassProbs = match.getTargetDistribution();
        double instancesInNode = 0;
        for (int c = 0; c < nodeClassProbs.length; c++) {
            instancesInNode += nodeClassProbs[c];
        }
        // an all-zero distribution would yield 0/0 = NaN and poison every sum; such a leaf adds nothing
        if (instancesInNode > 0) {
            for (int c = 0; c < classProbabilities.length; c++) {
                classProbabilities[c] += nodeClassProbs[c] / instancesInNode;
            }
        }
        nrValidModels += 1;
    }
    if (nrValidModels == 0) {
        // no tree was allowed to vote for this row
        Arrays.fill(result, DataType.getMissingCell());
        if (appendModelCount) {
            result[size - 1] = new IntCell(0);
        }
        return result;
    }
    int indexBest = -1;
    double probBest = -1;
    for (int c = 0; c < classProbabilities.length; c++) {
        double prob = classProbabilities[c];
        if (prob > probBest) {
            probBest = prob;
            indexBest = c;
        }
    }
    int index = 0;
    result[index++] = new StringCell(targetValues[indexBest].getNominalValue());
    if (appendConfidence) {
        // average over voting trees so the confidence is in [0, 1], consistent with the class confidences below
        result[index++] = new DoubleCell(probBest / nrValidModels);
    }
    if (appendClassConfidences) {
        for (NominalValueRepresentation nomVal : targetValues) {
            double prob = classProbabilities[nomVal.getAssignedInteger()] / nrValidModels;
            result[index++] = new DoubleCell(prob);
        }
    }
    if (appendModelCount) {
        result[index++] = new IntCell(nrValidModels);
    }
    return result;
}
Use of org.knime.core.data.def.StringCell in project knime-core by KNIME:
class RuleEngineNodeModel, method createRearranger.
/**
 * Builds the {@link ColumnRearranger} that appends the rule-engine result column to the input table.
 * <p>
 * Rules are evaluated in list order for each row, and the outcome of the first matching rule is emitted. If no
 * rule matches, the default label is used. It is one of:
 * <ul>
 * <li>the value of a referenced column, when the default label is a {@code $column$} reference,</li>
 * <li>a constant parsed as int, double or string,</li>
 * <li>a missing cell, when the label is empty.</li>
 * </ul>
 * The output column type is the common super type of all rule outcome types and the default label type. A
 * non-native common type is replaced by {@link StringCell#TYPE}.
 *
 * @param inSpec the spec of the input table
 * @param rules the rules to evaluate, in priority order
 * @return a rearranger that appends one new column
 * @throws InvalidSettingsException if the default label is configured as a column reference but is malformed or
 *             names a column not present in {@code inSpec}
 */
private ColumnRearranger createRearranger(final DataTableSpec inSpec, final List<Rule> rules) throws InvalidSettingsException {
    ColumnRearranger crea = new ColumnRearranger(inSpec);
    String newColName = DataTableSpec.getUniqueColumnName(inSpec, m_settings.getNewColName());
    // read once so that type inference here and cell creation below see the same value
    final String defaultLabel = m_settings.getDefaultLabel();
    final int defaultLabelColumnIndex;
    if (m_settings.getDefaultLabelIsColumn()) {
        if (defaultLabel.length() < 3) {
            throw new InvalidSettingsException("Default label is not a column reference");
        }
        if (!defaultLabel.startsWith("$") || !defaultLabel.endsWith("$")) {
            throw new InvalidSettingsException("Column references in default label must be enclosed in $");
        }
        String colRef = defaultLabel.substring(1, defaultLabel.length() - 1);
        defaultLabelColumnIndex = inSpec.findColumnIndex(colRef);
        if (defaultLabelColumnIndex == -1) {
            throw new InvalidSettingsException("Column '" + colRef + "' for default label does not exist in input table");
        }
    } else {
        defaultLabelColumnIndex = -1;
    }
    // determine output type from all rule outcomes and the default label
    List<DataType> types = new ArrayList<DataType>();
    for (Rule r : rules) {
        Object outcome = r.getOutcome();
        if (outcome instanceof ColumnReference) {
            types.add(((ColumnReference) outcome).spec.getType());
        } else if (outcome instanceof Double) {
            types.add(DoubleCell.TYPE);
        } else if (outcome instanceof Integer) {
            types.add(IntCell.TYPE);
        } else if (outcome.toString().length() > 0) {
            // empty outcomes do not constrain the output type
            types.add(StringCell.TYPE);
        }
    }
    if (defaultLabelColumnIndex >= 0) {
        types.add(inSpec.getColumnSpec(defaultLabelColumnIndex).getType());
    } else if (defaultLabel.length() > 0) {
        try {
            Integer.parseInt(defaultLabel);
            types.add(IntCell.TYPE);
        } catch (NumberFormatException ex) {
            try {
                Double.parseDouble(defaultLabel);
                types.add(DoubleCell.TYPE);
            } catch (NumberFormatException ex1) {
                types.add(StringCell.TYPE);
            }
        }
    }
    final DataType outType;
    if (types.size() > 0) {
        DataType temp = types.get(0);
        for (int i = 1; i < types.size(); i++) {
            temp = DataType.getCommonSuperType(temp, types.get(i));
        }
        if ((temp.getValueClasses().size() == 1) && temp.getValueClasses().contains(DataValue.class)) {
            // a non-native type, we replace it with string
            temp = StringCell.TYPE;
        }
        outType = temp;
    } else {
        outType = StringCell.TYPE;
    }
    // The constant default (used when no rule matches and no default column is referenced) does not depend on
    // the row, so build it once instead of re-parsing the label for every row.
    DataCell defaultCell = DataType.getMissingCell();
    if (defaultLabelColumnIndex < 0 && defaultLabel.length() > 0) {
        if (outType.equals(StringCell.TYPE)) {
            defaultCell = new StringCell(defaultLabel);
        } else {
            try {
                int i = Integer.parseInt(defaultLabel);
                // produce a double cell when other outcomes widened the column to double
                defaultCell = outType.equals(DoubleCell.TYPE) ? new DoubleCell(i) : new IntCell(i);
            } catch (NumberFormatException ex) {
                try {
                    defaultCell = new DoubleCell(Double.parseDouble(defaultLabel));
                } catch (NumberFormatException ex1) {
                    defaultCell = new StringCell(defaultLabel);
                }
            }
        }
    }
    final DataCell constantDefaultCell = defaultCell;
    DataColumnSpec cs = new DataColumnSpecCreator(newColName, outType).createSpec();
    crea.append(new SingleCellFactory(cs) {
        @Override
        public DataCell getCell(final DataRow row) {
            for (Rule r : rules) {
                if (r.matches(row)) {
                    Object outcome = r.getOutcome();
                    if (outcome instanceof ColumnReference) {
                        return adaptColumnCell(row.getCell(((ColumnReference) outcome).index));
                    } else if (outType.equals(IntCell.TYPE)) {
                        return new IntCell((Integer) outcome);
                    } else if (outType.equals(DoubleCell.TYPE)) {
                        // Integer outcomes land here when other rules yield doubles: widen, do not cast
                        return new DoubleCell(((Number) outcome).doubleValue());
                    } else {
                        return new StringCell(outcome.toString());
                    }
                }
            }
            if (defaultLabelColumnIndex >= 0) {
                return adaptColumnCell(row.getCell(defaultLabelColumnIndex));
            }
            return constantDefaultCell;
        }

        /**
         * Converts a cell read from an input column to the output column type. If the output column is
         * string-typed, non-missing non-string cells are replaced by their string representation. Missing cells
         * stay missing.
         */
        private DataCell adaptColumnCell(final DataCell cell) {
            if (outType.equals(StringCell.TYPE) && !cell.isMissing() && !cell.getType().equals(StringCell.TYPE)) {
                return new StringCell(cell.toString());
            }
            return cell;
        }
    });
    return crea;
}
Use of org.knime.core.data.def.StringCell in project knime-core by KNIME:
class LogisticRegressionContent, method createTablePortObject.
/**
 * Creates a table of the regression coefficients and their statistics.
 * <p>
 * For every logit there is one row per model parameter, followed by one row for the intercept (variable name
 * "Constant"). Each row holds the logit, the variable, the coefficient, its standard error, the z-score and the
 * p-value. Row keys are "Row1", "Row2", ... in output order.
 *
 * @param exec the execution context used to create the data container
 * @return the table of coefficients and statistics
 */
public BufferedDataTable createTablePortObject(final ExecutionContext exec) {
    final String[] columnNames = new String[] { "Logit", "Variable", "Coeff.", "Std. Err.", "z-score", "P>|z|" };
    final DataType[] columnTypes = new DataType[] { StringCell.TYPE, StringCell.TYPE, DoubleCell.TYPE, DoubleCell.TYPE, DoubleCell.TYPE, DoubleCell.TYPE };
    final BufferedDataContainer container =
        exec.createDataContainer(new DataTableSpec("Coefficients and Statistics", columnNames, columnTypes));
    final List<DataCell> logitCells = getLogits();
    final List<String> parameterNames = getParameters();
    int rowCount = 0;
    for (final DataCell logit : logitCells) {
        final Map<String, Double> coefficients = getCoefficients(logit);
        final Map<String, Double> stdErrs = getStandardErrors(logit);
        final Map<String, Double> zScores = getZScores(logit);
        final Map<String, Double> pValues = getPValues(logit);
        for (final String parameter : parameterNames) {
            final DataCell[] rowCells = new DataCell[] {
                new StringCell(logit.toString()),
                new StringCell(parameter),
                new DoubleCell(coefficients.get(parameter)),
                new DoubleCell(stdErrs.get(parameter)),
                new DoubleCell(zScores.get(parameter)),
                new DoubleCell(pValues.get(parameter)) };
            rowCount++;
            container.addRowToTable(new DefaultRow("Row" + rowCount, rowCells));
        }
        // the intercept is reported as an extra "Constant" row after the parameters of each logit
        final DataCell[] interceptCells = new DataCell[] {
            new StringCell(logit.toString()),
            new StringCell("Constant"),
            new DoubleCell(getIntercept(logit)),
            new DoubleCell(getInterceptStdErr(logit)),
            new DoubleCell(getInterceptZScore(logit)),
            new DoubleCell(getInterceptPValue(logit)) };
        rowCount++;
        container.addRowToTable(new DefaultRow("Row" + rowCount, interceptCells));
    }
    container.close();
    return container.getTable();
}
Use of org.knime.core.data.def.StringCell in project knime-core by KNIME:
class LogisticRegressionContent, method createTablePortObject (second occurrence).
/**
 * Creates a table of the regression coefficients and their statistics.
 * <p>
 * For every logit there is one row per model parameter, followed by one row for the intercept (variable name
 * "Constant"). Each row holds the logit, the variable, the coefficient, its standard error, the z-score and the
 * p-value. Row keys are "Row1", "Row2", ... in output order.
 *
 * @param exec the execution context used to create the data container
 * @return the table of coefficients and statistics
 */
public BufferedDataTable createTablePortObject(final ExecutionContext exec) {
    final String[] columnNames = new String[] { "Logit", "Variable", "Coeff.", "Std. Err.", "z-score", "P>|z|" };
    final DataType[] columnTypes = new DataType[] { StringCell.TYPE, StringCell.TYPE, DoubleCell.TYPE, DoubleCell.TYPE, DoubleCell.TYPE, DoubleCell.TYPE };
    final BufferedDataContainer container =
        exec.createDataContainer(new DataTableSpec("Coefficients and Statistics", columnNames, columnTypes));
    final List<DataCell> logitCells = getLogits();
    final List<String> parameterNames = getParameters();
    int rowCount = 0;
    for (final DataCell logit : logitCells) {
        final Map<String, Double> coefficients = getCoefficients(logit);
        final Map<String, Double> stdErrs = getStandardErrors(logit);
        final Map<String, Double> zScores = getZScores(logit);
        final Map<String, Double> pValues = getPValues(logit);
        for (final String parameter : parameterNames) {
            final DataCell[] rowCells = new DataCell[] {
                new StringCell(logit.toString()),
                new StringCell(parameter),
                new DoubleCell(coefficients.get(parameter)),
                new DoubleCell(stdErrs.get(parameter)),
                new DoubleCell(zScores.get(parameter)),
                new DoubleCell(pValues.get(parameter)) };
            rowCount++;
            container.addRowToTable(new DefaultRow("Row" + rowCount, rowCells));
        }
        // the intercept is reported as an extra "Constant" row after the parameters of each logit
        final DataCell[] interceptCells = new DataCell[] {
            new StringCell(logit.toString()),
            new StringCell("Constant"),
            new DoubleCell(getIntercept(logit)),
            new DoubleCell(getInterceptStdErr(logit)),
            new DoubleCell(getInterceptZScore(logit)),
            new DoubleCell(getInterceptPValue(logit)) };
        rowCount++;
        container.addRowToTable(new DefaultRow("Row" + rowCount, interceptCells));
    }
    container.close();
    return container.getTable();
}
Use of org.knime.core.data.def.StringCell in project knime-core by KNIME:
class AbstractTreeEnsembleModel, method createLearnAttributeRow.
/**
 * Converts a learn row into a row with one cell per attribute, according to this model's tree type.
 * <ul>
 * <li>{@code Ordinary}: the row is returned unchanged.</li>
 * <li>{@code BitVector}: the bit vector in the first cell becomes string cells "1" / "0".</li>
 * <li>{@code ByteVector}: the byte vector in the first cell becomes int cells.</li>
 * <li>{@code DoubleVector}: the double vector in the first cell becomes double cells.</li>
 * </ul>
 * For the vector types, {@code null} is returned if the first cell is missing or if the vector length differs
 * from the number of attributes in the model meta data. Note that the first cell is read for every tree type.
 *
 * @param learnRow the row to convert
 * @param learnSpec the spec of the learn row (not read by this implementation)
 * @return the converted row, or {@code null} as described above
 * @throws IllegalStateException if the tree type is not handled here
 */
public DataRow createLearnAttributeRow(final DataRow learnRow, final DataTableSpec learnSpec) {
    final TreeType treeType = getType();
    final DataCell vectorCell = learnRow.getCell(0);
    final int attributeCount = getMetaData().getNrAttributes();
    switch (treeType) {
        case Ordinary:
            return learnRow;
        case BitVector: {
            if (vectorCell.isMissing()) {
                return null;
            }
            final BitVectorValue bits = (BitVectorValue) vectorCell;
            if (bits.length() != attributeCount) {
                // TODO indicate error message
                return null;
            }
            // shared immutable cells for set / unset bits
            final DataCell one = new StringCell("1");
            final DataCell zero = new StringCell("0");
            final DataCell[] expanded = new DataCell[attributeCount];
            for (int a = 0; a < attributeCount; a++) {
                expanded[a] = bits.get(a) ? one : zero;
            }
            return new DefaultRow(learnRow.getKey(), expanded);
        }
        case ByteVector: {
            if (vectorCell.isMissing()) {
                return null;
            }
            final ByteVectorValue bytes = (ByteVectorValue) vectorCell;
            if (bytes.length() != attributeCount) {
                return null;
            }
            final DataCell[] expanded = new DataCell[attributeCount];
            for (int a = 0; a < attributeCount; a++) {
                expanded[a] = new IntCell(bytes.get(a));
            }
            return new DefaultRow(learnRow.getKey(), expanded);
        }
        case DoubleVector: {
            if (vectorCell.isMissing()) {
                return null;
            }
            final DoubleVectorValue doubles = (DoubleVectorValue) vectorCell;
            if (doubles.getLength() != attributeCount) {
                return null;
            }
            final DataCell[] expanded = new DataCell[attributeCount];
            for (int a = 0; a < attributeCount; a++) {
                expanded[a] = new DoubleCell(doubles.getValue(a));
            }
            return new DefaultRow(learnRow.getKey(), expanded);
        }
        default:
            throw new IllegalStateException("Type unknown (not implemented): " + treeType);
    }
}
Aggregations