Use of org.dmg.pmml.ArrayType in project knime-core (by KNIME):
class TreeModelPMMLTranslator, method setValuesFromPMMLSimpleSetPredicate.
/**
 * Copies a KNIME {@link PMMLSimpleSetPredicate} into an XMLBeans {@link SimpleSetPredicate}.
 * <p>
 * The field name, the boolean set operator (isIn / isNotIn) and a new Array child are written.
 * The Array holds the predicate's values as text, along with their count ({@code n}) and their
 * array type (int, real or string).
 *
 * @param to the XMLBeans predicate to fill in
 * @param from the KNIME predicate to read from
 * @throws IllegalStateException if the set operator or the array type is not one of the handled constants
 */
private static void setValuesFromPMMLSimpleSetPredicate(final SimpleSetPredicate to, final PMMLSimpleSetPredicate from) {
    to.setField(from.getSplitAttribute());

    // translate the KNIME set operator into its PMML counterpart
    final PMMLSetOperator knimeOperator = from.getSetOperator();
    final Enum pmmlOperator;
    switch (knimeOperator) {
        case IS_IN:
            pmmlOperator = SimpleSetPredicate.BooleanOperator.IS_IN;
            break;
        case IS_NOT_IN:
            pmmlOperator = SimpleSetPredicate.BooleanOperator.IS_NOT_IN;
            break;
        default:
            throw new IllegalStateException("Unknown set operator \"" + knimeOperator + "\".");
    }
    to.setBooleanOperator(pmmlOperator);

    // the Array element: value count plus the values as a single text node
    final Set<String> members = from.getValues();
    final ArrayType valueArray = to.addNewArray();
    valueArray.setN(BigInteger.valueOf(members.size()));
    final org.w3c.dom.Node domNode = valueArray.getDomNode();
    final org.w3c.dom.Document ownerDocument = domNode.getOwnerDocument();
    final String text = setToWhitspaceSeparatedString(members);
    domNode.appendChild(ownerDocument.createTextNode(text));

    // translate the KNIME array type into the PMML array type
    final PMMLArrayType knimeArrayType = from.getArrayType();
    final ArrayType.Type.Enum pmmlArrayType;
    switch (knimeArrayType) {
        case INT:
            pmmlArrayType = ArrayType.Type.INT;
            break;
        case REAL:
            pmmlArrayType = ArrayType.Type.REAL;
            break;
        case STRING:
            pmmlArrayType = ArrayType.Type.STRING;
            break;
        default:
            throw new IllegalStateException("Unknown array type \"" + knimeArrayType + "\".");
    }
    valueArray.setType(pmmlArrayType);
}
Use of org.dmg.pmml.ArrayType in project knime-core (by KNIME):
class PMMLClusterTranslator, method exportTo.
/**
 * {@inheritDoc}
 *
 * <p>Appends a center-based ClusteringModel named "k-means" to the document. The model
 * contains the mining schema, a distance comparison measure, one ClusteringField per used
 * column and one Cluster per prototype. Each Cluster's REAL array lists the prototype's
 * coordinates.
 */
@Override
public SchemaType exportTo(final PMMLDocument pmmlDoc, final PMMLPortObjectSpec spec) {
    final DerivedFieldMapper mapper = new DerivedFieldMapper(pmmlDoc);
    final PMML pmml = pmmlDoc.getPMML();
    final ClusteringModelDocument.ClusteringModel model = pmml.addNewClusteringModel();
    PMMLMiningSchemaTranslator.writeMiningSchema(spec, model);

    // --- model attributes ---
    model.setModelName("k-means");
    model.setFunctionName(MININGFUNCTION.CLUSTERING);
    model.setModelClass(ModelClass.CENTER_BASED);
    model.setNumberOfClusters(BigInteger.valueOf(m_nrOfClusters));

    // --- comparison measure: squared Euclidean when configured, Euclidean in every other case ---
    final ComparisonMeasureDocument.ComparisonMeasure measure = model.addNewComparisonMeasure();
    measure.setKind(Kind.DISTANCE);
    if (!ComparisonMeasure.squaredEuclidean.equals(m_measure)) {
        measure.addNewEuclidean();
    } else {
        measure.addNewSquaredEuclidean();
    }

    // --- one clustering field per used column, compared by absolute difference ---
    for (final String column : m_usedColumns) {
        final ClusteringFieldDocument.ClusteringField field = model.addNewClusteringField();
        field.setField(mapper.getDerivedFieldName(column));
        field.setCompareFunction(COMPAREFUNCTION.ABS_DIFF);
    }

    // --- one cluster per prototype; sizes are only written when there is exactly one
    //     coverage entry per prototype ---
    final boolean writeSizes = m_clusterCoverage != null && m_clusterCoverage.length == m_prototypes.length;
    for (int idx = 0; idx < m_prototypes.length; idx++) {
        final double[] center = m_prototypes[idx];
        final ClusterDocument.Cluster cluster = model.addNewCluster();
        cluster.setName(CLUSTER_NAME_PREFIX + idx);
        if (writeSizes) {
            cluster.setSize(BigInteger.valueOf(m_clusterCoverage[idx]));
        }

        final ArrayType coordinates = cluster.addNewArray();
        coordinates.setN(BigInteger.valueOf(center.length));
        coordinates.setType(Type.REAL);

        // every coordinate (including the last) is followed by a single space
        final StringBuilder text = new StringBuilder();
        for (final double coordinate : center) {
            text.append(coordinate).append(' ');
        }
        final XmlCursor cursor = coordinates.newCursor();
        cursor.setTextValue(text.toString());
        cursor.dispose();
    }
    return ClusteringModel.type;
}
Use of org.dmg.pmml.ArrayType in project knime-core (by KNIME):
class PMMLClusterTranslator, method initializeFrom.
/**
 * {@inheritDoc}
 *
 * <p>Reads the first ClusteringModel of the given document into this translator. It reads:
 * <ul>
 * <li>the used columns, taken from the ClusteringField elements and mapped back to KNIME
 * column names;</li>
 * <li>one prototype, label and optional size per Cluster;</li>
 * <li>the comparison measure.</li>
 * </ul>
 * An error is logged, but nothing is thrown, for missing value weights other than one and for
 * a similarity-kind comparison measure.
 *
 * @throws IllegalArgumentException in any of these cases:
 *             <ul>
 *             <li>a ClusteringField uses a compare function other than absDiff;</li>
 *             <li>a Cluster has no Array, or fewer values than there are clustering fields;</li>
 *             <li>the comparison measure is neither euclidean nor squaredEuclidean.</li>
 *             </ul>
 */
@Override
public void initializeFrom(final PMMLDocument pmmlDoc) {
    PMML pmml = pmmlDoc.getPMML();
    DerivedFieldMapper mapper = new DerivedFieldMapper(pmmlDoc);
    ClusteringModelDocument.ClusteringModel pmmlClusteringModel = pmml.getClusteringModelArray(0);

    // ---------------------------------------------------
    // initialize m_usedColumns from the ClusteringFields. This is the only place the columns are
    // added, so each field is added once and the prototype width below matches.
    for (ClusteringField cf : pmmlClusteringModel.getClusteringFieldArray()) {
        m_usedColumns.add(mapper.getColumnName(cf.getField()));
        if (COMPAREFUNCTION.ABS_DIFF != cf.getCompareFunction()) {
            LOGGER.error("Comparison Function " + cf.getCompareFunction().toString() + " is not supported!");
            throw new IllegalArgumentException("Only the absolute difference (\"absDiff\") as "
                + "compare function is supported!");
        }
    }

    // ---------------------------------------------------
    // initialize Clusters
    final int nrOfColumns = m_usedColumns.size();
    m_nrOfClusters = pmmlClusteringModel.sizeOfClusterArray();
    m_prototypes = new double[m_nrOfClusters][nrOfColumns];
    m_labels = new String[m_nrOfClusters];
    m_clusterCoverage = new int[m_nrOfClusters];
    for (int i = 0; i < m_nrOfClusters; i++) {
        ClusterDocument.Cluster currentCluster = pmmlClusteringModel.getClusterArray(i);
        m_labels[i] = currentCluster.getName();
        // in KNIME learner: m_labels[i] = "cluster_" + i;
        ArrayType clusterArray = currentCluster.getArray();
        if (clusterArray == null) {
            throw new IllegalArgumentException("Cluster \"" + m_labels[i]
                + "\" does not contain an array with its center coordinates.");
        }
        String[] stringValues = splitPmmlArrayContent(clusterArray);
        if (stringValues.length < nrOfColumns) {
            throw new IllegalArgumentException("Cluster \"" + m_labels[i] + "\" has " + stringValues.length
                + " center coordinate(s), but " + nrOfColumns + " clustering field(s) are defined.");
        }
        for (int j = 0; j < nrOfColumns; j++) {
            m_prototypes[i][j] = Double.parseDouble(stringValues[j]);
        }
        if (currentCluster.isSetSize()) {
            m_clusterCoverage[i] = currentCluster.getSize().intValue();
        }
    }

    // ---------------------------------------------------
    // Missing value weights: only weights of exactly one are supported. The check covers quoted
    // and unquoted array content alike. It only logs, and reports at most once.
    if (pmmlClusteringModel.isSetMissingValueWeights()) {
        ArrayType weights = pmmlClusteringModel.getMissingValueWeights().getArray();
        if (weights != null) {
            for (String weight : splitPmmlArrayContent(weights)) {
                if (weight.isEmpty()) {
                    continue; // an empty array yields a single empty token
                }
                boolean isOne;
                try {
                    isOne = Double.parseDouble(weight) == 1.0;
                } catch (NumberFormatException e) {
                    // an unparsable weight is reported like any other unsupported weight
                    isOne = false;
                }
                if (!isOne) {
                    LOGGER.error("Missing Value Weight not equals one" + " is not supported!");
                    break;
                }
            }
        }
    }

    // --------------------------------------------
    // initialize Comparison Measure
    ComparisonMeasureDocument.ComparisonMeasure pmmlComparisonMeasure = pmmlClusteringModel.getComparisonMeasure();
    if (pmmlComparisonMeasure.isSetSquaredEuclidean()) {
        m_measure = ComparisonMeasure.squaredEuclidean;
    } else if (pmmlComparisonMeasure.isSetEuclidean()) {
        m_measure = ComparisonMeasure.euclidean;
    } else {
        String measure = pmmlComparisonMeasure.getDomNode().getFirstChild().getNodeName();
        throw new IllegalArgumentException("\"" + ComparisonMeasure.euclidean + "\" and \""
            + ComparisonMeasure.squaredEuclidean + "\" are the only supported comparison "
            + "measures! Found " + measure + ".");
    }
    if (Kind.SIMILARITY == pmmlComparisonMeasure.getKind()) {
        LOGGER.error("A Similarity Kind of Comparison Measure is not " + "supported!");
    }
}

/**
 * Splits the text content of a PMML Array element into its individual values.
 *
 * <p>Content without double quotes is split at runs of whitespace. Content that contains double
 * quotes is treated as a list of quoted values: each value ends with a closing quote followed
 * by SPACE, and embedded quotes are escaped as BACKSLASH + DOUBLE_QUOT. The surrounding quotes
 * are removed and each value is trimmed.
 *
 * @param array the PMML array to read, not null
 * @return the values in document order
 */
private String[] splitPmmlArrayContent(final ArrayType array) {
    final XmlCursor cursor = array.newCursor();
    String content;
    try {
        content = cursor.getTextValue();
    } finally {
        // cursors should be released explicitly once they are no longer needed
        cursor.dispose();
    }
    content = content.trim();
    if (!content.contains(DOUBLE_QUOT)) {
        return content.split("\\s+");
    }
    // Escaped quotes are parked as TAB so they survive the removal of the enclosing quotes, and
    // are restored afterwards.
    /* TODO A value that itself contains a TAB is corrupted by this placeholder (the TAB comes
     * back as a quote), e.g. <Array n="3" type="string">"Cheval Blanc" "TABTAB" "Latour"</Array>. */
    content = content.replace(BACKSLASH + DOUBLE_QUOT, TAB);
    final String[] values = content.split(DOUBLE_QUOT + SPACE);
    for (int j = 0; j < values.length; j++) {
        values[j] = values[j].replace(DOUBLE_QUOT, "").replace(TAB, DOUBLE_QUOT).trim();
    }
    return values;
}
Use of org.dmg.pmml.ArrayType in project knime-core (by KNIME):
class PMMLDecisionTreeTranslator, method addTreeNode.
/**
 * Recursively converts a KNIME decision tree node, and all of its children, into the given
 * PMML Node element.
 * <p>
 * The node's predicate is derived from the split of its parent:
 * <ul>
 * <li>the root node gets True;</li>
 * <li>a continuous split gets {@code <=} for the first child and True for the second;</li>
 * <li>a binary nominal split gets an isIn set with a quoted string array;</li>
 * <li>a nominal split gets an equality test;</li>
 * <li>a PMML split gets the stored predicate.</li>
 * </ul>
 *
 * @param pmmlNode the PMML element to fill in for {@code node}
 * @param node A KNIME DecisionTree node
 * @param mapper translates KNIME column names into the (derived) field names used in the PMML
 * @throws IllegalArgumentException if the node is in neither partition of a binary nominal
 *             split, or if the parent's node type is not supported
 */
private static void addTreeNode(final NodeDocument.Node pmmlNode, final DecisionTreeNode node, final DerivedFieldMapper mapper) {
    pmmlNode.setId(String.valueOf(node.getOwnIndex()));
    pmmlNode.setScore(node.getMajorityClass().toString());
    // The record count is only written when it is positive.
    // NOTE(review): presumably it is 0 for trees that were read in and then exported again; confirm.
    if (node.getEntireClassCount() > 0) {
        pmmlNode.setRecordCount(node.getEntireClassCount());
    }
    if (node instanceof DecisionTreeNodeSplitPMML) {
        int defaultChild = ((DecisionTreeNodeSplitPMML) node).getDefaultChildIndex();
        if (defaultChild > -1) {
            pmmlNode.setDefaultChild(String.valueOf(defaultChild));
        }
    }
    // adding the predicate, which depends on the parent's split
    DecisionTreeNode parent = node.getParent();
    if (parent == null) {
        // When the parent is null, it is the root Node.
        // For root node, the predicate is always True.
        pmmlNode.addNewTrue();
    } else if (parent instanceof DecisionTreeNodeSplitContinuous) {
        // SimplePredicate case: first child is "<= threshold", second child is the remainder
        DecisionTreeNodeSplitContinuous splitNode = (DecisionTreeNodeSplitContinuous) parent;
        if (splitNode.getIndex(node) == 0) {
            SimplePredicate pmmlSimplePredicate = pmmlNode.addNewSimplePredicate();
            pmmlSimplePredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
            pmmlSimplePredicate.setOperator(Operator.LESS_OR_EQUAL);
            pmmlSimplePredicate.setValue(String.valueOf(splitNode.getThreshold()));
        } else if (splitNode.getIndex(node) == 1) {
            pmmlNode.addNewTrue();
        }
    } else if (parent instanceof DecisionTreeNodeSplitNominalBinary) {
        // SimpleSetPredicate case
        DecisionTreeNodeSplitNominalBinary splitNode = (DecisionTreeNodeSplitNominalBinary) parent;
        SimpleSetPredicate pmmlSimpleSetPredicate = pmmlNode.addNewSimpleSetPredicate();
        pmmlSimpleSetPredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
        pmmlSimpleSetPredicate.setBooleanOperator(SimpleSetPredicate.BooleanOperator.IS_IN);
        ArrayType pmmlArray = pmmlSimpleSetPredicate.addNewArray();
        pmmlArray.setType(ArrayType.Type.STRING);
        DataCell[] splitValues = splitNode.getSplitValues();
        List<Integer> indices;
        if (splitNode.getIndex(node) == SplitNominalBinary.LEFT_PARTITION) {
            indices = splitNode.getLeftChildIndices();
        } else if (splitNode.getIndex(node) == SplitNominalBinary.RIGHT_PARTITION) {
            indices = splitNode.getRightChildIndices();
        } else {
            throw new IllegalArgumentException("Split node is neither "
                + "contained in the right nor in the left partition.");
        }
        // Values of a string Array are whitespace separated. Each value is therefore enclosed in
        // double quotes, with embedded quotes escaped, so that values containing spaces stay
        // single values and n matches. This is the same encoding that
        // PMMLPredicateTranslator.initSimpleSetPred uses.
        StringBuilder classSet = new StringBuilder();
        for (Integer i : indices) {
            if (classSet.length() > 0) {
                classSet.append(' ');
            }
            classSet.append('"').append(splitValues[i].toString().replace("\"", "\\\"")).append('"');
        }
        pmmlArray.setN(BigInteger.valueOf(indices.size()));
        XmlCursor xmlCursor = pmmlArray.newCursor();
        xmlCursor.setTextValue(classSet.toString());
        xmlCursor.dispose();
    } else if (parent instanceof DecisionTreeNodeSplitNominal) {
        // one child per nominal value: "field == value"
        DecisionTreeNodeSplitNominal splitNode = (DecisionTreeNodeSplitNominal) parent;
        SimplePredicate pmmlSimplePredicate = pmmlNode.addNewSimplePredicate();
        pmmlSimplePredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
        pmmlSimplePredicate.setOperator(Operator.EQUAL);
        int nodeIndex = parent.getIndex(node);
        pmmlSimplePredicate.setValue(splitNode.getSplitValues()[nodeIndex].toString());
    } else if (parent instanceof DecisionTreeNodeSplitPMML) {
        DecisionTreeNodeSplitPMML splitNode = (DecisionTreeNodeSplitPMML) parent;
        int nodeIndex = parent.getIndex(node);
        // get the PMML predicate of the current node from its parent
        PMMLPredicate predicate = splitNode.getSplitPred()[nodeIndex];
        if (predicate instanceof PMMLCompoundPredicate) {
            // surrogates as used in GBT
            exportCompoundPredicate(pmmlNode, (PMMLCompoundPredicate) predicate, mapper);
        } else {
            // NOTE(review): this rewrites the split attribute of the tree's own predicate object
            // in place, so exporting the same tree twice maps the name twice. Confirm this is intended.
            predicate.setSplitAttribute(mapper.getDerivedFieldName(predicate.getSplitAttribute()));
            // delegate the writing to the predicate translator
            PMMLPredicateTranslator.exportTo(predicate, pmmlNode);
        }
    } else {
        throw new IllegalArgumentException("Node Type " + parent.getClass() + " is not supported!");
    }
    // adding score distribution (class counts)
    for (Entry<DataCell, Double> entry : node.getClassCounts().entrySet()) {
        ScoreDistribution pmmlScoreDist = pmmlNode.addNewScoreDistribution();
        pmmlScoreDist.setValue(entry.getKey().toString());
        pmmlScoreDist.setRecordCount(entry.getValue());
    }
    // adding children
    if (!(node instanceof DecisionTreeNodeLeaf)) {
        for (int i = 0; i < node.getChildCount(); i++) {
            addTreeNode(pmmlNode.addNewNode(), node.getChildAt(i), mapper);
        }
    }
}
Use of org.dmg.pmml.ArrayType in project knime-core (by KNIME):
class PMMLPredicateTranslator, method initSimpleSetPred.
/**
 * Converts a {@link PMMLSimpleSetPredicate} ({@code sp}) to a {@link SimpleSetPredicate} ({@code setPred}).
 * <p>
 * The field, the boolean operator and a new Array child are written. The Array carries the value
 * count, the array type, and the values separated by single spaces (with a trailing space).
 * For string arrays, every value is wrapped in double quotes, and quotes inside a value are
 * escaped with a backslash.
 *
 * @param sp A {@link PMMLSimpleSetPredicate}.
 * @param setPred The {@link SimpleSetPredicate} to initialize.
 * @since 2.9
 */
public static void initSimpleSetPred(final PMMLSimpleSetPredicate sp, final SimpleSetPredicate setPred) {
    setPred.setField(sp.getSplitAttribute());
    setPred.setBooleanOperator(getOperator(sp.getSetOperator()));

    final ArrayType valueArray = setPred.addNewArray();
    valueArray.setN(BigInteger.valueOf(sp.getValues().size()));
    valueArray.setType(getType(sp.getArrayType()));

    // string values are quoted so that values containing whitespace remain single entries
    final boolean quoted = sp.getArrayType() == PMMLArrayType.STRING;
    final StringBuilder text = new StringBuilder();
    for (final String member : sp.getValues()) {
        if (quoted) {
            text.append('"').append(member.replace("\"", "\\\"")).append('"');
        } else {
            text.append(member);
        }
        text.append(' ');
    }

    // the values become the text content of the Array element
    final XmlCursor cursor = valueArray.newCursor();
    cursor.setTextValue(text.toString());
    cursor.dispose();
}
Aggregations