use of org.knime.core.data.DataColumnDomainCreator in project knime-core by knime.
the class PMMLRuleEditorNodeModel method createRearranger.
/**
 * Builds the {@link ColumnRearranger} that appends or replaces the rule outcome column.
 * <p>
 * Every non-comment rule line from the settings is parsed into a condition and a constant outcome. As a side
 * effect each parsed rule is also added to {@code ruleSet} as a new {@link SimpleRule} carrying the condition,
 * the outcome as score and the rule weight.
 *
 * @param tableSpec The spec of the input table.
 * @param ruleSet The {@link RuleSet} xml object that receives one simple rule per parsed line.
 * @param parser The parser for the rules.
 * @return The {@link ColumnRearranger} computing the outcome column.
 * @throws ParseException If a rule cannot be parsed; the rule text and its line number are attached through
 *             {@code Util.addContext}.
 * @throws InvalidSettingsException If a column should be appended but no column name is configured.
 */
private ColumnRearranger createRearranger(final DataTableSpec tableSpec, final RuleSet ruleSet, final PMMLRuleParser parser) throws ParseException, InvalidSettingsException {
    if (m_settings.isAppendColumn() && m_settings.getNewColName().isEmpty()) {
        throw new InvalidSettingsException("No name for prediction column provided");
    }
    // Insertion-ordered, so the domain lists the scores in the order the rules introduce them.
    final Set<String> scores = new LinkedHashSet<String>();
    final List<DataType> resultTypes = new ArrayList<DataType>();
    final List<Pair<PMMLPredicate, Expression>> parsedRules = new ArrayList<Pair<PMMLPredicate, Expression>>();
    int lineNo = 0;
    for (final String text : m_settings.rules()) {
        // Comment lines are counted as well, so lineNo is the 1-based position among all lines.
        lineNo++;
        if (!RuleSupport.isComment(text)) {
            try {
                final ParseState state = new ParseState(text);
                final PMMLPredicate condition = parser.parseBooleanExpression(state);
                final SimpleRule xmlRule = ruleSet.addNewSimpleRule();
                setCondition(xmlRule, condition);
                // A rule has the shape "<condition> => <outcome>".
                state.skipWS();
                state.consumeText("=>");
                state.skipWS();
                final Expression result = parser.parseOutcomeOperand(state, null);
                // Only constants are allowed in the outcomes.
                assert result.isConstant() : result;
                parsedRules.add(new Pair<PMMLPredicate, Expression>(condition, result));
                resultTypes.add(result.getOutputType());
                xmlRule.setScore(result.toString());
                xmlRule.setWeight(weightForRule(xmlRule, lineNo, text));
                scores.add(xmlRule.getScore());
            } catch (ParseException e) {
                throw Util.addContext(e, text, lineNo);
            }
        }
    }
    final DataType columnType = RuleEngineNodeModel.computeOutputType(resultTypes, true);
    final ColumnRearranger result = new ColumnRearranger(tableSpec);
    final String columnName;
    if (m_settings.isAppendColumn()) {
        columnName = DataTableSpec.getUniqueColumnName(tableSpec, m_settings.getNewColName());
    } else {
        columnName = m_settings.getReplaceColumn();
    }
    final DataColumnSpecCreator specCreator = new DataColumnSpecCreator(columnName, columnType);
    specCreator.setDomain(new DataColumnDomainCreator(toCells(scores, columnType)).createDomain());
    final SingleCellFactory factory = new SingleCellFactory(true, specCreator.createSpec()) {
        /**
         * Returns the outcome of the first rule whose condition evaluates to {@code Boolean.TRUE}, or a missing
         * cell when no rule matches.
         */
        @Override
        public DataCell getCell(final DataRow row) {
            for (final Pair<PMMLPredicate, Expression> rule : parsedRules) {
                if (rule.getFirst().evaluate(row, tableSpec) != Boolean.TRUE) {
                    continue;
                }
                return rule.getSecond().evaluate(row, null).getValue();
            }
            return DataType.getMissingCell();
        }
    };
    if (m_settings.isAppendColumn()) {
        result.append(factory);
    } else {
        result.replace(factory, m_settings.getReplaceColumn());
    }
    return result;
}
use of org.knime.core.data.DataColumnDomainCreator in project knime-core by knime.
the class PMMLRuleSetPredictorNodeModel method createRearranger.
/**
 * Constructs the {@link ColumnRearranger} for computing the new columns.
 * <p>
 * The rearranger appends the prediction column and, when {@code addConfidence} is set, a confidence column in
 * front of it. The prediction is computed with the first rule selection method of the model (first hit,
 * weighted max or weighted sum). For first hit and weighted max the record count (and, with a validation
 * column, the number of correct predictions) of the selected rule is updated and written back to the original
 * rule set model after processing.
 *
 * @param obj The {@link PMMLPortObject} of the preprocessing model.
 * @param spec The {@link DataTableSpec} of the table.
 * @param replaceColumn Should replace the {@code outputColumnName}?
 * @param outputColumnName The output column name (which might be an existing).
 * @param addConfidence Should add the confidence values to a column?
 * @param confidenceColumnName The name of the confidence column.
 * @param validationColumnIdx Index of the validation column, {@code -1} if not specified.
 * @param processConcurrently Should be {@code false} when the statistics are to be computed.
 * @return The {@link ColumnRearranger} computing the result.
 * @throws InvalidSettingsException If the model does not contain exactly one rule set model.
 * @throws UnsupportedOperationException If the model is not scorable or the target type is not supported.
 */
private static ColumnRearranger createRearranger(final PMMLPortObject obj, final DataTableSpec spec, final boolean replaceColumn, final String outputColumnName, final boolean addConfidence, final String confidenceColumnName, final int validationColumnIdx, final boolean processConcurrently) throws InvalidSettingsException {
    List<Node> models = obj.getPMMLValue().getModels(PMMLModelType.RuleSetModel);
    if (models.size() != 1) {
        throw new InvalidSettingsException("Expected exactly one RuleSetModel, but got: " + models.size());
    }
    final PMMLRuleTranslator translator = new PMMLRuleTranslator();
    obj.initializeModelTranslator(translator);
    if (!translator.isScorable()) {
        throw new UnsupportedOperationException("The model is not scorable.");
    }
    final List<PMMLRuleTranslator.Rule> rules = translator.getRules();
    ColumnRearranger ret = new ColumnRearranger(spec);
    final List<DataColumnSpec> targetCols = obj.getSpec().getTargetCols();
    // Without a declared target column the prediction is produced as string.
    final DataType dataType = targetCols.isEmpty() ? StringCell.TYPE : targetCols.get(0).getType();
    DataColumnSpecCreator specCreator = new DataColumnSpecCreator(outputColumnName, dataType);
    // Collect the possible values of the prediction column from the rule outcomes (in rule order).
    Set<DataCell> outcomes = new LinkedHashSet<>();
    for (Rule rule : rules) {
        DataCell outcome;
        if (dataType.equals(BooleanCell.TYPE)) {
            outcome = BooleanCellFactory.create(rule.getOutcome());
        } else if (dataType.equals(StringCell.TYPE)) {
            outcome = new StringCell(rule.getOutcome());
        } else if (dataType.equals(DoubleCell.TYPE)) {
            try {
                outcome = new DoubleCell(Double.parseDouble(rule.getOutcome()));
            } catch (NumberFormatException ignored) {
                // Unparsable outcomes are left out of the domain; result(String) maps them to missing cells.
                continue;
            }
        } else if (dataType.equals(IntCell.TYPE)) {
            try {
                outcome = new IntCell(Integer.parseInt(rule.getOutcome()));
            } catch (NumberFormatException ignored) {
                // Unparsable outcomes are left out of the domain; result(String) maps them to missing cells.
                continue;
            }
        } else if (dataType.equals(LongCell.TYPE)) {
            try {
                outcome = new LongCell(Long.parseLong(rule.getOutcome()));
            } catch (NumberFormatException ignored) {
                // Unparsable outcomes are left out of the domain; result(String) maps them to missing cells.
                continue;
            }
        } else {
            throw new UnsupportedOperationException("Unknown outcome type: " + dataType);
        }
        outcomes.add(outcome);
    }
    specCreator.setDomain(new DataColumnDomainCreator(outcomes).createDomain());
    DataColumnSpec colSpec = specCreator.createSpec();
    final RuleSelectionMethod ruleSelectionMethod = translator.getSelectionMethodList().get(0);
    final String defaultScore = translator.getDefaultScore();
    final Double defaultConfidence = translator.getDefaultConfidence();
    // Column order must match toCells(...): confidence first (if requested), prediction last.
    final DataColumnSpec[] specs;
    if (addConfidence) {
        specs = new DataColumnSpec[] {
            new DataColumnSpecCreator(DataTableSpec.getUniqueColumnName(ret.createSpec(), confidenceColumnName),
                DoubleCell.TYPE).createSpec(),
            colSpec };
    } else {
        specs = new DataColumnSpec[] { colSpec };
    }
    final int oldColumnIndex = replaceColumn ? ret.indexOf(outputColumnName) : -1;
    ret.append(new AbstractCellFactory(processConcurrently, specs) {
        /** Possible target values from the data dictionary; never {@code null}, empty if unavailable. */
        private final List<String> m_values;
        {
            // The model may come without a target column (see the string fallback for dataType), so the
            // data dictionary lookup must not index into an empty list.
            final Map<String, List<String>> dd = translator.getDataDictionary();
            final List<String> declared = targetCols.isEmpty() ? null : dd.get(targetCols.get(0).getName());
            m_values = declared == null ? java.util.Collections.<String>emptyList() : declared;
        }

        /**
         * {@inheritDoc}
         */
        @Override
        public DataCell[] getCells(final DataRow row) {
            // See http://www.dmg.org/v4-1/RuleSet.html#Rule
            switch (ruleSelectionMethod.getCriterion().intValue()) {
                case RuleSelectionMethod.Criterion.INT_FIRST_HIT:
                {
                    Pair<DataCell, Double> resultAndConfidence = selectFirstHit(row);
                    return toCells(resultAndConfidence);
                }
                case RuleSelectionMethod.Criterion.INT_WEIGHTED_MAX:
                {
                    Pair<DataCell, Double> resultAndConfidence = selectWeightedMax(row);
                    return toCells(resultAndConfidence);
                }
                case RuleSelectionMethod.Criterion.INT_WEIGHTED_SUM:
                {
                    Pair<DataCell, Double> resultAndConfidence = selectWeightedSum(row);
                    return toCells(resultAndConfidence);
                }
                default:
                    throw new UnsupportedOperationException(ruleSelectionMethod.getCriterion().toString());
            }
        }

        /**
         * Converts the pair to a {@link DataCell} array.
         *
         * @param resultAndConfidence The {@link Pair}.
         * @return The result and possibly the confidence (confidence first, missing if unknown).
         */
        private DataCell[] toCells(final Pair<DataCell, Double> resultAndConfidence) {
            if (!addConfidence) {
                return new DataCell[] { resultAndConfidence.getFirst() };
            }
            if (resultAndConfidence.getSecond() == null) {
                return new DataCell[] { DataType.getMissingCell(), resultAndConfidence.getFirst() };
            }
            return new DataCell[] { new DoubleCell(resultAndConfidence.getSecond()), resultAndConfidence.getFirst() };
        }

        /**
         * Computes the result and the confidence using the weighted sum method: the weights of all matching
         * rules are summed per score, the score with the largest sum wins and the confidence is that sum
         * divided by the number of matching rules.
         *
         * @param row A {@link DataRow}
         * @return The result and the confidence.
         * @throws IllegalStateException If a matching rule has a score that is not in the data dictionary.
         */
        private Pair<DataCell, Double> selectWeightedSum(final DataRow row) {
            final Map<String, Double> scoreToSumWeight = new LinkedHashMap<String, Double>();
            for (String val : m_values) {
                scoreToSumWeight.put(val, 0.0);
            }
            int matchedRuleCount = 0;
            for (final PMMLRuleTranslator.Rule rule : rules) {
                if (Boolean.TRUE.equals(rule.getCondition().evaluate(row, spec))) {
                    ++matchedRuleCount;
                    Double sumWeight = scoreToSumWeight.get(rule.getOutcome());
                    if (sumWeight == null) {
                        throw new IllegalStateException("The score value: " + rule.getOutcome() + " is not in the data dictionary.");
                    }
                    final Double wRaw = rule.getWeight();
                    final double w = wRaw == null ? 0.0 : wRaw.doubleValue();
                    scoreToSumWeight.put(rule.getOutcome(), sumWeight + w);
                }
            }
            double maxSumWeight = Double.NEGATIVE_INFINITY;
            String bestScore = null;
            for (Entry<String, Double> entry : scoreToSumWeight.entrySet()) {
                final double d = entry.getValue().doubleValue();
                if (d > maxSumWeight) {
                    maxSumWeight = d;
                    bestScore = entry.getKey();
                }
            }
            if (bestScore == null || matchedRuleCount == 0) {
                return pair(result(defaultScore), defaultConfidence);
            }
            return pair(result(bestScore), maxSumWeight / matchedRuleCount);
        }

        /**
         * Helper method to create {@link Pair}s.
         *
         * @param f The first element.
         * @param s The second element.
         * @return The new pair.
         */
        private <F, S> Pair<F, S> pair(final F f, final S s) {
            return new Pair<F, S>(f, s);
        }

        /**
         * Computes the result and the confidence using the weighted max method: among the matching rules the
         * first one with the largest weight wins. A rule without weight is treated as weight {@code 0.0}, the
         * same way {@link #selectWeightedSum(DataRow)} handles it.
         *
         * @param row A {@link DataRow}
         * @return The result and the confidence.
         */
        private Pair<DataCell, Double> selectWeightedMax(final DataRow row) {
            double maxWeight = Double.NEGATIVE_INFINITY;
            PMMLRuleTranslator.Rule bestRule = null;
            for (final PMMLRuleTranslator.Rule rule : rules) {
                if (Boolean.TRUE.equals(rule.getCondition().evaluate(row, spec))) {
                    // getWeight() may be null; unboxing it directly would throw a NullPointerException.
                    final Double wRaw = rule.getWeight();
                    final double w = wRaw == null ? 0.0 : wRaw.doubleValue();
                    if (w > maxWeight) {
                        maxWeight = w;
                        bestRule = rule;
                    }
                }
            }
            if (bestRule == null) {
                return pair(result(defaultScore), defaultConfidence);
            }
            bestRule.setRecordCount(bestRule.getRecordCount() + 1);
            DataCell result = result(bestRule);
            if (validationColumnIdx >= 0) {
                if (row.getCell(validationColumnIdx).equals(result)) {
                    bestRule.setNbCorrect(bestRule.getNbCorrect() + 1);
                }
            }
            Double confidence = bestRule.getConfidence();
            return pair(result, confidence == null ? defaultConfidence : confidence);
        }

        /**
         * Selects the outcome of the rule and converts it to the proper outcome type.
         *
         * @param rule A {@link Rule}.
         * @return The {@link DataCell} representing the result. (May be missing.)
         */
        private DataCell result(final PMMLRuleTranslator.Rule rule) {
            String outcome = rule.getOutcome();
            return result(outcome);
        }

        /**
         * Constructs the {@link DataCell} from its {@link String} representation ({@code outcome}) using the
         * type of the prediction column.
         *
         * @param outcome The {@link String} representation, may be {@code null}.
         * @return The {@link DataCell}; a missing cell if {@code outcome} is {@code null} or cannot be parsed
         *         as a number of the expected type.
         */
        private DataCell result(final String outcome) {
            if (outcome == null) {
                return DataType.getMissingCell();
            }
            try {
                if (dataType.isCompatible(BooleanValue.class)) {
                    return BooleanCellFactory.create(outcome);
                }
                if (IntCell.TYPE.isASuperTypeOf(dataType)) {
                    return new IntCell(Integer.parseInt(outcome));
                }
                if (LongCell.TYPE.isASuperTypeOf(dataType)) {
                    return new LongCell(Long.parseLong(outcome));
                }
                if (DoubleCell.TYPE.isASuperTypeOf(dataType)) {
                    return new DoubleCell(Double.parseDouble(outcome));
                }
                return new StringCell(outcome);
            } catch (NumberFormatException e) {
                return new MissingCell(outcome + "\n" + e.getMessage());
            }
        }

        /**
         * Selects the first rule that matches and computes the confidence and result for the {@code row}.
         *
         * @param row A {@link DataRow}.
         * @return The result and the confidence.
         */
        private Pair<DataCell, Double> selectFirstHit(final DataRow row) {
            for (final PMMLRuleTranslator.Rule rule : rules) {
                if (Boolean.TRUE.equals(rule.getCondition().evaluate(row, spec))) {
                    rule.setRecordCount(rule.getRecordCount() + 1);
                    DataCell result = result(rule);
                    if (validationColumnIdx >= 0) {
                        if (row.getCell(validationColumnIdx).equals(result)) {
                            rule.setNbCorrect(rule.getNbCorrect() + 1);
                        }
                    }
                    Double confidence = rule.getConfidence();
                    return pair(result, confidence == null ? defaultConfidence : confidence);
                }
            }
            return pair(result(defaultScore), defaultConfidence);
        }

        /**
         * {@inheritDoc}
         * <p>
         * Copies the collected record counts (and, with a validation column, the number of correct
         * predictions) into the original rule set model, provided it consists of simple rules only.
         */
        @Override
        public void afterProcessing() {
            super.afterProcessing();
            // NOTE(review): the return value is unused; presumably called for a side effect - confirm.
            obj.getPMMLValue();
            RuleSetModel ruleSet = translator.getOriginalRuleSetModel();
            assert rules.size() == ruleSet.getRuleSet().getSimpleRuleList().size() + ruleSet.getRuleSet().getCompoundRuleList().size();
            if (ruleSet.getRuleSet().getSimpleRuleList().size() == rules.size()) {
                for (int i = 0; i < rules.size(); ++i) {
                    Rule rule = rules.get(i);
                    final SimpleRule simpleRuleArray = ruleSet.getRuleSet().getSimpleRuleArray(i);
                    synchronized (simpleRuleArray) /*synchronized fixes AP-6766 */
                    {
                        simpleRuleArray.setRecordCount(rule.getRecordCount());
                        if (validationColumnIdx >= 0) {
                            simpleRuleArray.setNbCorrect(rule.getNbCorrect());
                        } else if (simpleRuleArray.isSetNbCorrect()) {
                            simpleRuleArray.unsetNbCorrect();
                        }
                    }
                }
            }
        }
    });
    if (replaceColumn) {
        ret.remove(outputColumnName);
        // NOTE(review): with addConfidence the appended order is [confidence, prediction], so this index
        // appears to address the confidence column rather than the prediction - confirm against
        // ColumnRearranger#remove/#move semantics before changing.
        ret.move(ret.getColumnCount() - 1 - (addConfidence ? 1 : 0), oldColumnIndex);
    }
    return ret;
}
use of org.knime.core.data.DataColumnDomainCreator in project knime-core by knime.
the class StatisticsTable method calculateAllMoments.
/**
 * Computes <b>all statistical moments in a single pass</b> over the table: per column the minimum, maximum and
 * missing value count, and for double-compatible columns additionally sum, mean and (sample) variance. The
 * table spec with the observed bounds as domain is created at the end.
 *
 * @param rowCount Row count of table for progress, may be NaN if unknown.
 * @param exec object to check with if user canceled the operation, may be {@code null}
 * @throws CanceledExecutionException if user canceled
 * @throws IllegalArgumentException if rowCount argument is negative
 */
protected void calculateAllMoments(final double rowCount, final ExecutionMonitor exec) throws CanceledExecutionException {
    if (rowCount < 0.0) {
        throw new IllegalArgumentException("rowCount argument must not < 0: " + rowCount);
    }
    final DataTableSpec inputSpec = m_table.getDataTableSpec();
    final int columnCount = inputSpec.getNumColumns();
    // Per column: number of non-missing numeric cells and the sum of their squares (both start at zero).
    final int[] numericCount = new int[columnCount];
    final double[] squaredSum = new double[columnCount];
    final DataValueComparator[] comparators = new DataValueComparator[columnCount];
    for (int col = 0; col < columnCount; col++) {
        comparators[col] = inputSpec.getColumnSpec(col).getType().getComparator();
        assert comparators[col] != null;
    }

    int processed = 0;
    final RowIterator rows = m_table.iterator();
    while (rows.hasNext()) {
        final DataRow current = rows.next();
        if (exec != null) {
            final double fraction = Double.isNaN(rowCount) ? 0.0 : processed / rowCount;
            exec.setProgress(fraction, "Calculating statistics, processing row " + (processed + 1) + " (\"" + current.getKey() + "\")");
            // throws exception if user canceled
            exec.checkCanceled();
        }
        for (int col = 0; col < columnCount; col++) {
            final DataCell value = current.getCell(col);
            if (value.isMissing()) {
                m_missingValueCnt[col]++;
                continue;
            }
            // track the bounds of every column, numeric or not
            if (m_minValues[col] == null || comparators[col].compare(value, m_minValues[col]) < 0) {
                m_minValues[col] = value;
            }
            if (m_maxValues[col] == null || comparators[col].compare(m_maxValues[col], value) < 0) {
                m_maxValues[col] = value;
            }
            // sums are only maintained for double-compatible columns
            if (!inputSpec.getColumnSpec(col).getType().isCompatible(DoubleValue.class)) {
                continue;
            }
            final double number = ((DoubleValue) value).getDoubleValue();
            // a NaN sum is replaced by the current value instead of being added to
            m_sum[col] = Double.isNaN(m_sum[col]) ? number : m_sum[col] + number;
            squaredSum[col] += number * number;
            numericCount[col]++;
        }
        calculateMomentInSubClass(current);
        processed++;
    }
    m_nrRows = processed;

    for (int col = 0; col < columnCount; col++) {
        final int n = numericCount[col];
        if (n == 0 || m_minValues[col] == null) {
            // no numeric value seen in this column: everything is reported as missing / NaN
            final DataCell missing = DataType.getMissingCell();
            m_minValues[col] = missing;
            m_maxValues[col] = missing;
            m_meanValues[col] = Double.NaN;
            m_varianceValues[col] = Double.NaN;
            continue;
        }
        m_meanValues[col] = m_sum[col] / n;
        double variance = 0.0;
        if (n > 1) {
            variance = (squaredSum[col] - ((m_sum[col] * m_sum[col]) / n)) / (n - 1);
        }
        // round-off errors resulting in negative variance values
        // NOTE(review): the bound -1.0E8 looks like it was meant to be -1.0E-8 - confirm before changing.
        if (variance < 0.0 && variance > -1.0E8) {
            variance = 0.0;
        }
        m_varianceValues[col] = variance;
        assert m_varianceValues[col] >= 0.0 : "Variance cannot be negative (column \"" + inputSpec.getColumnSpec(col).getName() + "\": " + m_varianceValues[col];
    }

    // compute resulting table spec: keep the possible values, use the observed bounds
    final DataTableSpec currentSpec = m_table.getDataTableSpec();
    final DataColumnSpec[] resultColumns = new DataColumnSpec[currentSpec.getNumColumns()];
    for (int col = 0; col < resultColumns.length; col++) {
        final DataColumnSpec original = currentSpec.getColumnSpec(col);
        Set<DataCell> possibleValues = null;
        if (original.getDomain() != null) {
            possibleValues = original.getDomain().getValues();
        }
        DataCell lower = null;
        if (m_minValues[col] != null && !m_minValues[col].isMissing()) {
            lower = m_minValues[col];
        }
        DataCell upper = null;
        if (m_maxValues[col] != null && !m_maxValues[col].isMissing()) {
            upper = m_maxValues[col];
        }
        final DataColumnSpecCreator creator = new DataColumnSpecCreator(original);
        creator.setDomain(new DataColumnDomainCreator(possibleValues, lower, upper).createDomain());
        resultColumns[col] = creator.createSpec();
    }
    m_tSpec = new DataTableSpec(resultColumns);
}
use of org.knime.core.data.DataColumnDomainCreator in project knime-core by knime.
the class DecTreePredictorNodeModel method createOutTableSpec.
/**
 * Creates the output spec: the input data spec followed by one double column per prediction value (only if the
 * distribution should be shown) and the string prediction column.
 *
 * @param inSpecs the input port specs; the model spec is read at {@code INMODELPORT}, the data spec at
 *            {@code INDATAPORT}
 * @return the output table spec, or {@code null} if the distribution is requested but the prediction values
 *         cannot be determined
 */
private DataTableSpec createOutTableSpec(final PortObjectSpec[] inSpecs) {
    LinkedList<DataCell> classValues = null;
    if (m_showDistribution.getBooleanValue()) {
        classValues = getPredictionValues((PMMLPortObjectSpec) inSpecs[INMODELPORT]);
        if (classValues == null) {
            // no out spec can be determined
            return null;
        }
    }
    final int distributionCount = classValues == null ? 0 : classValues.size();
    final DataTableSpec dataSpec = (DataTableSpec) inSpecs[INDATAPORT];
    final UniqueNameGenerator names = new UniqueNameGenerator(dataSpec);
    final DataColumnSpec[] appended = new DataColumnSpec[distributionCount + 1];
    // The distribution columns get the fixed domain [0,1] (no preferred renderer is set).
    final DataColumnDomain unitInterval =
        new DataColumnDomainCreator(new DoubleCell(0.0), new DoubleCell(1.0)).createDomain();
    int next = 0;
    if (classValues != null) {
        // add all distribution columns
        for (final DataCell classValue : classValues) {
            final DataColumnSpecCreator creator = names.newCreator(classValue.toString(), DoubleCell.TYPE);
            creator.setDomain(unitInterval);
            appended[next++] = creator.createSpec();
        }
    }
    // add the prediction column
    appended[next] = names.newColumn("Prediction (DecTree)", StringCell.TYPE);
    return new DataTableSpec(dataSpec, new DataTableSpec(appended));
}
use of org.knime.core.data.DataColumnDomainCreator in project knime-core by knime.
the class AutoBinner method calcDomainBoundsIfNeccessary.
/**
 * Determines the per column min/max values of the given data if not already
 * present in the domain.
 * <p>
 * Only the requested columns whose domain has no bounds are scanned. Missing cells and NaN values are ignored;
 * a column in which no value was found keeps its original spec.
 *
 * @param data the data
 * @param exec the execution context
 * @param recalcValuesFor The names of the columns to check, may be {@code null} or empty (the data is then
 *            returned unchanged)
 * @return The data with extended domain information
 * @throws InvalidSettingsException if a requested column does not exist or is not numeric
 * @throws CanceledExecutionException if the execution was canceled
 */
public BufferedDataTable calcDomainBoundsIfNeccessary(final BufferedDataTable data, final ExecutionContext exec, final List<String> recalcValuesFor) throws InvalidSettingsException, CanceledExecutionException {
    if (null == recalcValuesFor || recalcValuesFor.isEmpty()) {
        return data;
    }
    final DataTableSpec inSpec = data.getDataTableSpec();
    // Indices of the requested columns that have no bounds yet.
    final List<Integer> valuesI = new ArrayList<Integer>();
    for (String colName : recalcValuesFor) {
        DataColumnSpec colSpec = inSpec.getColumnSpec(colName);
        if (colSpec == null) {
            throw new InvalidSettingsException("The column \"" + colName + "\" is not part of the input table.");
        }
        if (!colSpec.getType().isCompatible(DoubleValue.class)) {
            throw new InvalidSettingsException("Can only process numeric " + "data. The column \"" + colSpec.getName() + "\" is not numeric.");
        }
        if (!colSpec.getDomain().hasBounds()) {
            valuesI.add(inSpec.findColumnIndex(colName));
        }
    }
    if (valuesI.isEmpty()) {
        return data;
    }
    // Start with an "empty" range (min > max). Double.MIN_VALUE must not be used as the initial maximum:
    // it is the smallest positive double, not the most negative one.
    final Map<Integer, Double> min = new HashMap<Integer, Double>();
    final Map<Integer, Double> max = new HashMap<Integer, Double>();
    for (int col : valuesI) {
        min.put(col, Double.POSITIVE_INFINITY);
        max.put(col, Double.NEGATIVE_INFINITY);
    }
    int c = 0;
    for (DataRow row : data) {
        c++;
        exec.checkCanceled();
        exec.setProgress(c / (double) data.getRowCount());
        for (int col : valuesI) {
            if (row.getCell(col).isMissing()) {
                // missing cells are no DoubleValue and carry no information about the bounds
                continue;
            }
            double val = ((DoubleValue) row.getCell(col)).getDoubleValue();
            if (min.get(col) > val) {
                min.put(col, val);
            }
            if (max.get(col) < val) {
                max.put(col, val);
            }
        }
    }
    List<DataColumnSpec> newColSpecList = new ArrayList<DataColumnSpec>();
    int cc = 0;
    for (DataColumnSpec columnSpec : inSpec) {
        // Only columns that were actually scanned have an entry; requested columns that already had bounds
        // are not in the maps and must keep their spec.
        final Double lower = min.get(cc);
        final Double upper = max.get(cc);
        if (lower != null && upper != null && lower.doubleValue() <= upper.doubleValue()) {
            DataColumnSpecCreator specCreator = new DataColumnSpecCreator(columnSpec);
            DataColumnDomainCreator domainCreator = new DataColumnDomainCreator(new DoubleCell(lower.doubleValue()), new DoubleCell(upper.doubleValue()));
            specCreator.setDomain(domainCreator.createDomain());
            newColSpecList.add(specCreator.createSpec());
        } else {
            newColSpecList.add(columnSpec);
        }
        cc++;
    }
    DataTableSpec spec = new DataTableSpec(newColSpecList.toArray(new DataColumnSpec[0]));
    return exec.createSpecReplacerTable(data, spec);
}
Aggregations