use of org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate in project knime-core by knime.
the class DecisionTreeLearnerNodeModel method buildTree.
/**
 * Recursively induces the decision tree.
 *
 * <p>A leaf is returned when the class distribution of {@code table} is pure enough, when no valid
 * split is found, or when the partitioner reports that the split yields no useful partitioning.
 * Otherwise the table is partitioned and one child is built per partition, either by direct
 * recursion or in a separate {@link ParallelBuilding} thread.
 *
 * @param table the {@link InMemoryTable} representing the data for this
 *            node to determine the split and after that perform
 *            partitioning
 * @param exec the execution context for progress information
 * @param depth the current recursion depth
 * @param splitQualityMeasure the quality measure handed to the split finder and to all child builds
 * @param parallelProcessing asked whether a thread is available and handed to all child builds
 * @return the node induced from {@code table}: a leaf or a PMML split node with its children
 * @throws CanceledExecutionException if {@code exec} reports that the execution was canceled
 * @throws IllegalAccessException if raised by one of the invoked helpers (origin not visible in
 *             this method)
 */
private DecisionTreeNode buildTree(final InMemoryTable table, final ExecutionContext exec, final int depth,
        final SplitQualityMeasure splitQualityMeasure, final ParallelProcessing parallelProcessing)
        throws CanceledExecutionException, IllegalAccessException {
    exec.checkCanceled();
    // derive this node's id from the counter
    int nodeId = m_counter.getAndIncrement();
    // both values are read before any path frees the underlying rows
    DataCell majorityClass = table.getMajorityClassAsCell();
    LinkedHashMap<DataCell, Double> frequencies = table.getClassFrequencies();

    // if the distribution allows for a leaf
    if (table.isPureEnough()) {
        return createLeafAndReportProgress(table, exec, depth, nodeId, majorityClass, frequencies);
    }

    // find the best splits for all attributes
    SplitFinder splittFinder = new SplitFinder(table, splitQualityMeasure,
        m_averageSplitpoint.getBooleanValue(), m_minNumberRecordsPerNode.getIntValue(),
        m_binaryNominalSplitMode.getBooleanValue(), m_maxNumNominalsForCompleteComputation.getIntValue());
    // check for enough memory
    checkMemory();
    // get the best split among the best attribute splits
    Split split = splittFinder.getSplit();
    // if no best split could be evaluated, create a leaf node
    if (split == null || !split.isValidSplit()) {
        return createLeafAndReportProgress(table, exec, depth, nodeId, majorityClass, frequencies);
    }

    // partition the attribute lists according to this split
    Partitioner partitioner = new Partitioner(table, split, m_minNumberRecordsPerNode.getIntValue());
    if (!partitioner.couldBeUsefulPartitioned()) {
        return createLeafAndReportProgress(table, exec, depth, nodeId, majorityClass, frequencies);
    }

    // get the just created partitions
    InMemoryTable[] partitionTables = partitioner.getPartitionTables();
    // recursively build the child nodes
    DecisionTreeNode[] children = new DecisionTreeNode[partitionTables.length];
    ArrayList<ParallelBuilding> threads = new ArrayList<ParallelBuilding>();
    int i = 0;
    for (InMemoryTable partitionTable : partitionTables) {
        exec.checkCanceled();
        // small partitions (rows * attributes below 10000) or no available thread:
        // build the child in the calling thread
        if (partitionTable.getNumberDataRows() * m_numberAttributes < 10000
            || !parallelProcessing.isThreadAvailable()) {
            children[i] = buildTree(partitionTable, exec, depth + 1, splitQualityMeasure, parallelProcessing);
        } else {
            String threadName = "Build thread, node: " + nodeId + "." + i;
            ParallelBuilding buildThread = new ParallelBuilding(threadName, partitionTable, exec, depth + 1, i,
                splitQualityMeasure, parallelProcessing);
            LOGGER.debug("Start new parallel building thread: " + threadName);
            threads.add(buildThread);
            buildThread.start();
        }
        i++;
    }

    // collect the children built in parallel; directly built children are already in the array
    // NOTE(review): getResultNode() presumably waits for the thread to finish - confirm in ParallelBuilding
    for (ParallelBuilding buildThread : threads) {
        children[buildThread.getThreadIndex()] = buildThread.getResultNode();
        exec.checkCanceled();
        Throwable failure = buildThread.getException();
        if (failure != null) {
            for (ParallelBuilding buildThread2 : threads) {
                buildThread2.stop();
            }
            // keep the original exception as cause so its type and stack trace are not lost
            throw new RuntimeException(failure.getMessage(), failure);
        }
    }
    threads.clear();

    String splitAttribute = split.getSplitAttributeName();
    if (split instanceof SplitContinuous) {
        // numeric attribute: first child takes "<= value", second child takes "> value"
        String splitValue = Double.toString(((SplitContinuous) split).getBestSplitValue());
        PMMLPredicate[] splitPredicates = new PMMLPredicate[] {
            new PMMLSimplePredicate(splitAttribute, PMMLOperator.LESS_OR_EQUAL, splitValue),
            new PMMLSimplePredicate(splitAttribute, PMMLOperator.GREATER_THAN, splitValue) };
        return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, frequencies, splitAttribute,
            splitPredicates, children);
    } else if (split instanceof SplitNominalNormal) {
        // nominal attribute: one equality predicate per child, taken from the split values by index
        DataCell[] splitValues = ((SplitNominalNormal) split).getSplitValues();
        int num = children.length;
        PMMLPredicate[] splitPredicates = new PMMLPredicate[num];
        for (int j = 0; j < num; j++) {
            splitPredicates[j] =
                new PMMLSimplePredicate(splitAttribute, PMMLOperator.EQUAL, splitValues[j].toString());
        }
        return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, frequencies, splitAttribute,
            splitPredicates, children);
    } else {
        // binary nominal: two "is in" set predicates, index 0 = left partition, index 1 = right partition
        SplitNominalBinary splitNominalBinary = (SplitNominalBinary) split;
        DataCell[] splitValues = splitNominalBinary.getSplitValues();
        int[][] indices = new int[][] { splitNominalBinary.getIntMappingsLeftPartition(),
            splitNominalBinary.getIntMappingsRightPartition() };
        PMMLPredicate[] splitPredicates = new PMMLPredicate[2];
        for (int j = 0; j < splitPredicates.length; j++) {
            PMMLSimpleSetPredicate pred = new PMMLSimpleSetPredicate(splitAttribute, PMMLSetOperator.IS_IN);
            pred.setArrayType(PMMLArrayType.STRING);
            // the mappings are indices into splitValues; insertion order is kept
            LinkedHashSet<String> values = new LinkedHashSet<String>();
            for (int index : indices[j]) {
                values.add(splitValues[index].toString());
            }
            pred.setValues(values);
            splitPredicates[j] = pred;
        }
        return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, frequencies, splitAttribute,
            splitPredicates, children);
    }
}

/**
 * Frees the rows of {@code table}, adds its sum of weights to the finished counter, reports the
 * resulting progress on {@code exec} and returns a leaf for the given node id.
 */
private DecisionTreeNodeLeaf createLeafAndReportProgress(final InMemoryTable table, final ExecutionContext exec,
        final int depth, final int nodeId, final DataCell majorityClass,
        final LinkedHashMap<DataCell, Double> frequencies) {
    // free memory
    table.freeUnderlyingDataRows();
    double value = m_finishedCounter.incrementAndGet(table.getSumOfWeights());
    exec.setProgress(value / m_alloverRowCount, "Created node with id " + nodeId + " at level " + depth);
    return new DecisionTreeNodeLeaf(nodeId, majorityClass, frequencies);
}
use of org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate in project knime-core by knime.
the class DecisionTreeLearnerNodeModel2 method buildTree.
/**
 * Recursively induces the decision tree.
 *
 * <p>A leaf is returned when the class distribution of {@code table} is pure enough, when no valid
 * split is found, or when the partitioner reports that the split yields no useful partitioning.
 * Otherwise the table is partitioned and one child is built per partition, either by direct
 * recursion or in a separate {@link ParallelBuilding} thread.
 *
 * @param table the {@link InMemoryTable} representing the data for this
 *            node to determine the split and after that perform
 *            partitioning
 * @param exec the execution context for progress information
 * @param depth the current recursion depth
 * @param splitQualityMeasure the quality measure handed to the split computation and to all child
 *            builds
 * @param parallelProcessing asked whether a thread is available and handed to all child builds
 * @param firstSplitCol index of the column to split on at the root; only evaluated when
 *            {@code depth == 0} and the "use first split column" setting is enabled
 * @return the node induced from {@code table}: a leaf or a PMML split node with its children
 * @throws CanceledExecutionException if {@code exec} reports that the execution was canceled
 * @throws IllegalAccessException if raised by one of the invoked helpers (origin not visible in
 *             this method)
 */
private DecisionTreeNode buildTree(final InMemoryTable table, final ExecutionContext exec, final int depth,
        final SplitQualityMeasure splitQualityMeasure, final ParallelProcessing parallelProcessing,
        final int firstSplitCol) throws CanceledExecutionException, IllegalAccessException {
    exec.checkCanceled();
    // derive this node's id from the counter
    int nodeId = m_counter.getAndIncrement();
    // both values are read before any path frees the underlying rows
    DataCell majorityClass = table.getMajorityClassAsCell();
    LinkedHashMap<DataCell, Double> frequencies = table.getClassFrequencies();

    // if the distribution allows for a leaf
    if (table.isPureEnough()) {
        return createLeafAndReportProgress(table, exec, depth, nodeId, majorityClass, frequencies);
    }

    Split split = null;
    // find best split in specified column for first split
    if (depth == 0 && m_useFirstSplitCol.getBooleanValue()) {
        if (table.isNominal(firstSplitCol)) {
            if (m_binaryNominalSplitMode.getBooleanValue()) {
                split = new SplitNominalBinary(table, firstSplitCol, splitQualityMeasure,
                    m_minNumberRecordsPerNode.getIntValue(),
                    m_maxNumNominalsForCompleteComputation.getIntValue());
            } else {
                split = new SplitNominalNormal(table, firstSplitCol, splitQualityMeasure,
                    m_minNumberRecordsPerNode.getIntValue());
            }
        } else {
            split = new SplitContinuous(table, firstSplitCol, splitQualityMeasure,
                m_averageSplitpoint.getBooleanValue(), m_minNumberRecordsPerNode.getIntValue());
        }
        // only a warning is recorded here; the split stays set, so the validity check
        // below decides whether the root becomes a leaf
        if (Double.isNaN(split.getBestQualityMeasure()) || split.getBestQualityMeasure() == 0.0) {
            m_warningMessageSb.append("The specified root split column \"")
                .append(split.getSplitAttributeName()).append("\" does not contain a valid split.");
        }
    }
    if (split == null) {
        // no root split column found or selected
        // find the best splits for all attributes
        SplitFinder splittFinder = new SplitFinder(table, splitQualityMeasure,
            m_averageSplitpoint.getBooleanValue(), m_minNumberRecordsPerNode.getIntValue(),
            m_binaryNominalSplitMode.getBooleanValue(), m_maxNumNominalsForCompleteComputation.getIntValue());
        // check for enough memory
        checkMemory();
        // get the best split among the best attribute splits
        split = splittFinder.getSplit();
    }
    // if no best split could be evaluated, create a leaf node
    if (split == null || !split.isValidSplit()) {
        return createLeafAndReportProgress(table, exec, depth, nodeId, majorityClass, frequencies);
    }

    // partition the attribute lists according to this split
    Partitioner partitioner = new Partitioner(table, split, m_minNumberRecordsPerNode.getIntValue());
    if (!partitioner.couldBeUsefulPartitioned()) {
        return createLeafAndReportProgress(table, exec, depth, nodeId, majorityClass, frequencies);
    }

    // get the just created partitions
    InMemoryTable[] partitionTables = partitioner.getPartitionTables();
    // recursively build the child nodes
    DecisionTreeNode[] children = new DecisionTreeNode[partitionTables.length];
    ArrayList<ParallelBuilding> threads = new ArrayList<ParallelBuilding>();
    int i = 0;
    for (InMemoryTable partitionTable : partitionTables) {
        exec.checkCanceled();
        // small partitions (rows * attributes below 10000) or no available thread:
        // build the child in the calling thread
        if (partitionTable.getNumberDataRows() * m_numberAttributes < 10000
            || !parallelProcessing.isThreadAvailable()) {
            children[i] = buildTree(partitionTable, exec, depth + 1, splitQualityMeasure, parallelProcessing,
                firstSplitCol);
        } else {
            // firstSplitCol is not handed to the thread; this method reads it only at depth 0
            String threadName = "Build thread, node: " + nodeId + "." + i;
            ParallelBuilding buildThread = new ParallelBuilding(threadName, partitionTable, exec, depth + 1, i,
                splitQualityMeasure, parallelProcessing);
            LOGGER.debug("Start new parallel building thread: " + threadName);
            threads.add(buildThread);
            buildThread.start();
        }
        i++;
    }

    // collect the children built in parallel; directly built children are already in the array
    // NOTE(review): getResultNode() presumably waits for the thread to finish - confirm in ParallelBuilding
    for (ParallelBuilding buildThread : threads) {
        children[buildThread.getThreadIndex()] = buildThread.getResultNode();
        exec.checkCanceled();
        Throwable failure = buildThread.getException();
        if (failure != null) {
            for (ParallelBuilding buildThread2 : threads) {
                buildThread2.stop();
            }
            // keep the original exception as cause so its type and stack trace are not lost
            throw new RuntimeException(failure.getMessage(), failure);
        }
    }
    threads.clear();

    String splitAttribute = split.getSplitAttributeName();
    if (split instanceof SplitContinuous) {
        // numeric attribute: first child takes "<= value", second child takes "> value"
        String splitValue = Double.toString(((SplitContinuous) split).getBestSplitValue());
        PMMLPredicate[] splitPredicates = new PMMLPredicate[] {
            new PMMLSimplePredicate(splitAttribute, PMMLOperator.LESS_OR_EQUAL, splitValue),
            new PMMLSimplePredicate(splitAttribute, PMMLOperator.GREATER_THAN, splitValue) };
        return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, frequencies, splitAttribute,
            splitPredicates, children);
    } else if (split instanceof SplitNominalNormal) {
        // nominal attribute: one equality predicate per child, taken from the split values by index
        DataCell[] splitValues = ((SplitNominalNormal) split).getSplitValues();
        int num = children.length;
        PMMLPredicate[] splitPredicates = new PMMLPredicate[num];
        for (int j = 0; j < num; j++) {
            splitPredicates[j] =
                new PMMLSimplePredicate(splitAttribute, PMMLOperator.EQUAL, splitValues[j].toString());
        }
        return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, frequencies, splitAttribute,
            splitPredicates, children);
    } else {
        // binary nominal: two "is in" set predicates, index 0 = left partition, index 1 = right partition
        SplitNominalBinary splitNominalBinary = (SplitNominalBinary) split;
        DataCell[] splitValues = splitNominalBinary.getSplitValues();
        int[][] indices = new int[][] { splitNominalBinary.getIntMappingsLeftPartition(),
            splitNominalBinary.getIntMappingsRightPartition() };
        PMMLPredicate[] splitPredicates = new PMMLPredicate[2];
        for (int j = 0; j < splitPredicates.length; j++) {
            PMMLSimpleSetPredicate pred = new PMMLSimpleSetPredicate(splitAttribute, PMMLSetOperator.IS_IN);
            pred.setArrayType(PMMLArrayType.STRING);
            // the mappings are indices into splitValues; insertion order is kept
            LinkedHashSet<String> values = new LinkedHashSet<String>();
            for (int index : indices[j]) {
                values.add(splitValues[index].toString());
            }
            pred.setValues(values);
            splitPredicates[j] = pred;
        }
        return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, frequencies, splitAttribute,
            splitPredicates, children);
    }
}

/**
 * Frees the rows of {@code table}, adds its sum of weights to the finished counter, reports the
 * resulting progress on {@code exec} and returns a leaf for the given node id.
 */
private DecisionTreeNodeLeaf createLeafAndReportProgress(final InMemoryTable table, final ExecutionContext exec,
        final int depth, final int nodeId, final DataCell majorityClass,
        final LinkedHashMap<DataCell, Double> frequencies) {
    // free memory
    table.freeUnderlyingDataRows();
    double value = m_finishedCounter.incrementAndGet(table.getSumOfWeights());
    exec.setProgress(value / m_alloverRowCount, "Created node with id " + nodeId + " at level " + depth);
    return new DecisionTreeNodeLeaf(nodeId, majorityClass, frequencies);
}
use of org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate in project knime-core by knime.
the class TreeNodeNominalConditionTest method testToPMMLPredicate.
/**
 * Checks {@link TreeNodeNominalCondition#toPMMLPredicate()} for a condition that rejects missing
 * values (expects a simple predicate) and for one that accepts them (expects an OR compound of an
 * equality predicate and an isMissing predicate).
 *
 * @throws Exception if the test setup or the conversion fails
 */
@Test
public void testToPMMLPredicate() throws Exception {
    final TestDataGenerator generator = new TestDataGenerator(new TreeEnsembleLearnerConfiguration(false));
    final TreeNominalColumnData column = generator.createNominalAttributeColumn("A,A,B,C,C,D", "testcol", 0);

    // condition built with 3 and without missing values: expected to be a plain "== D" predicate
    final PMMLPredicate rejectingMissing =
        new TreeNodeNominalCondition(column.getMetaData(), 3, false).toPMMLPredicate();
    assertThat(rejectingMissing, instanceOf(PMMLSimplePredicate.class));
    final PMMLSimplePredicate equalsD = (PMMLSimplePredicate) rejectingMissing;
    assertEquals("Wrong operator", PMMLOperator.EQUAL, equalsD.getOperator());
    assertEquals("Wrong split value", "D", equalsD.getThreshold());

    // condition built with 0 and with missing values: expected to be "== A OR isMissing"
    final PMMLPredicate acceptingMissing =
        new TreeNodeNominalCondition(column.getMetaData(), 0, true).toPMMLPredicate();
    assertThat(acceptingMissing, instanceOf(PMMLCompoundPredicate.class));
    final PMMLCompoundPredicate orPredicate = (PMMLCompoundPredicate) acceptingMissing;
    assertEquals("Wrong boolean operator.", PMMLBooleanOperator.OR, orPredicate.getBooleanOperator());
    final List<PMMLPredicate> parts = orPredicate.getPredicates();
    assertEquals("Wrong number of predicates in compound predicate.", 2, parts.size());

    // first part: equality on the column's attribute
    assertThat(parts.get(0), instanceOf(PMMLSimplePredicate.class));
    final PMMLSimplePredicate equalsA = (PMMLSimplePredicate) parts.get(0);
    assertEquals("Wrong operator", PMMLOperator.EQUAL, equalsA.getOperator());
    assertEquals("Wrong split value", "A", equalsA.getThreshold());
    assertEquals("Wrong attribute.", column.getMetaData().getAttributeName(), equalsA.getSplitAttribute());

    // second part: the missing-value alternative
    assertThat(parts.get(1), instanceOf(PMMLSimplePredicate.class));
    final PMMLSimplePredicate isMissing = (PMMLSimplePredicate) parts.get(1);
    assertEquals("Should be isMissing", PMMLOperator.IS_MISSING, isMissing.getOperator());
}
use of org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate in project knime-core by knime.
the class TreeNodeNominalCondition method toPMMLPredicate.
/**
 * {@inheritDoc}
 *
 * <p>Builds an EQUAL predicate on this condition's attribute and value. If the condition accepts
 * missing values, that predicate is OR-combined with an IS_MISSING predicate on the same attribute.
 */
@Override
public PMMLPredicate toPMMLPredicate() {
    final PMMLSimplePredicate equalsValue =
        new PMMLSimplePredicate(getAttributeName(), PMMLOperator.EQUAL, getValue());
    if (acceptsMissings()) {
        // wrap into "equalsValue OR isMissing" so that missing values match as well
        final PMMLSimplePredicate isMissing = new PMMLSimplePredicate();
        isMissing.setSplitAttribute(getAttributeName());
        isMissing.setOperator(PMMLOperator.IS_MISSING);

        final PMMLCompoundPredicate equalsOrMissing = new PMMLCompoundPredicate(PMMLBooleanOperator.OR);
        equalsOrMissing.addPredicate(equalsValue);
        equalsOrMissing.addPredicate(isMissing);
        return equalsOrMissing;
    }
    // missing values are rejected: the plain equality predicate is sufficient
    return equalsValue;
}
use of org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate in project knime-core by knime.
the class TreeNodeNumericCondition method toPMMLPredicate.
/**
 * {@inheritDoc}
 *
 * <p>The two {@code ...OrMissing} operators are always translated into an OR compound of the
 * comparison and an IS_MISSING predicate. Every other operator is mapped through its
 * {@code m_pmmlOperator} and only wrapped into such a compound if this condition accepts missing
 * values.
 *
 * @throws IllegalStateException if the numeric operator has no associated PMML operator
 */
@Override
public PMMLPredicate toPMMLPredicate() {
    final PMMLCompoundPredicate orMissing = new PMMLCompoundPredicate(PMMLBooleanOperator.OR);
    final String threshold = Double.toString(m_splitValue);

    // NOTE(review): in the two cases below the IS_MISSING predicate is created with a threshold,
    // whereas the generic path further down creates it without one - confirm this is intended
    switch (m_numericOperator) {
        case LargerThanOrMissing:
            orMissing.addPredicate(
                new PMMLSimplePredicate(getAttributeName(), PMMLOperator.GREATER_THAN, threshold));
            orMissing.addPredicate(
                new PMMLSimplePredicate(getAttributeName(), PMMLOperator.IS_MISSING, threshold));
            return orMissing;
        case LessThanOrEqualOrMissing:
            orMissing.addPredicate(
                new PMMLSimplePredicate(getAttributeName(), PMMLOperator.LESS_OR_EQUAL, threshold));
            orMissing.addPredicate(
                new PMMLSimplePredicate(getAttributeName(), PMMLOperator.IS_MISSING, threshold));
            return orMissing;
        default:
            // all remaining operators are handled by the generic mapping below
            break;
    }

    final PMMLOperator mappedOperator = m_numericOperator.m_pmmlOperator;
    if (mappedOperator == null) {
        throw new IllegalStateException("There is no equivalent PMMLOperator for this NumericOperator.");
    }
    final PMMLSimplePredicate comparison = new PMMLSimplePredicate(getAttributeName(), mappedOperator, threshold);
    if (acceptsMissings()) {
        // "comparison OR isMissing" so that missing values match as well
        orMissing.addPredicate(comparison);
        final PMMLSimplePredicate isMissing = new PMMLSimplePredicate();
        isMissing.setSplitAttribute(getAttributeName());
        isMissing.setOperator(PMMLOperator.IS_MISSING);
        orMissing.addPredicate(isMissing);
        return orMissing;
    }
    // missing values are rejected: the plain comparison is sufficient
    return comparison;
}
Aggregations