use of org.dmg.pmml.PMMLDocument.PMML in project knime-core by knime.
the class TreeModelPMMLTranslator method exportTo.
/**
 * {@inheritDoc}
 * <p>
 * Appends a new {@code TreeModel} element named "DecisionTree" to the document, writes the mining schema
 * for {@code spec} into it and serializes the tree held in {@code m_treeModel} starting at its root node.
 *
 * @throws IllegalStateException if the wrapped model is neither a classification nor a regression tree
 */
@Override
public SchemaType exportTo(final PMMLDocument pmmlDoc, final PMMLPortObjectSpec spec) {
    final TreeModelDocument.TreeModel pmmlTree = pmmlDoc.getPMML().addNewTreeModel();
    PMMLMiningSchemaTranslator.writeMiningSchema(spec, pmmlTree);
    pmmlTree.setModelName("DecisionTree");

    // The mining function is derived from the concrete type of the wrapped model.
    final boolean classification = m_treeModel instanceof TreeModelClassification;
    if (!classification && !(m_treeModel instanceof TreeModelRegression)) {
        throw new IllegalStateException("Unknown tree model \"" + m_treeModel.getClass().getSimpleName() + "\".");
    }
    pmmlTree.setFunctionName(classification ? MININGFUNCTION.CLASSIFICATION : MININGFUNCTION.REGRESSION);

    // Split characteristic: multi-split as soon as the recursive check on the root says so.
    final AbstractTreeNode root = m_treeModel.getRootNode();
    pmmlTree.setSplitCharacteristic(
        isMultiSplitRecursive(root) ? SplitCharacteristic.MULTI_SPLIT : SplitCharacteristic.BINARY_SPLIT);

    // Fixed strategies for missing values and for nodes without a matching child.
    pmmlTree.setMissingValueStrategy(MISSINGVALUESTRATEGY.NONE);
    pmmlTree.setNoTrueChildStrategy(NOTRUECHILDSTRATEGY.RETURN_LAST_PREDICTION);

    // Serialize the tree itself into a freshly added root <Node>.
    addTreeNode(pmmlTree.addNewNode(), root);
    return TreeModelDocument.TreeModel.type;
}
use of org.dmg.pmml.PMMLDocument.PMML in project knime-core by knime.
the class PMMLClusterTranslator method exportTo.
/**
 * {@inheritDoc}
 * <p>
 * Appends a center-based {@code ClusteringModel} named "k-means" to the document. The model contains the
 * comparison measure (squared euclidean if {@code m_measure} says so, euclidean otherwise), one clustering
 * field per used column and one cluster per prototype. The cluster size is only written when exactly one
 * coverage value per prototype is available.
 */
@Override
public SchemaType exportTo(final PMMLDocument pmmlDoc, final PMMLPortObjectSpec spec) {
    DerivedFieldMapper mapper = new DerivedFieldMapper(pmmlDoc);
    PMML pmml = pmmlDoc.getPMML();
    ClusteringModelDocument.ClusteringModel clusteringModel = pmml.addNewClusteringModel();
    PMMLMiningSchemaTranslator.writeMiningSchema(spec, clusteringModel);
    // ---------------------------------------------------
    // set clustering model attributes
    clusteringModel.setModelName("k-means");
    clusteringModel.setFunctionName(MININGFUNCTION.CLUSTERING);
    clusteringModel.setModelClass(ModelClass.CENTER_BASED);
    clusteringModel.setNumberOfClusters(BigInteger.valueOf(m_nrOfClusters));
    // ---------------------------------------------------
    // set comparison measure
    ComparisonMeasureDocument.ComparisonMeasure pmmlComparisonMeasure = clusteringModel.addNewComparisonMeasure();
    pmmlComparisonMeasure.setKind(Kind.DISTANCE);
    if (ComparisonMeasure.squaredEuclidean.equals(m_measure)) {
        pmmlComparisonMeasure.addNewSquaredEuclidean();
    } else {
        pmmlComparisonMeasure.addNewEuclidean();
    }
    // set clustering fields
    for (String colName : m_usedColumns) {
        ClusteringFieldDocument.ClusteringField pmmlClusteringField = clusteringModel.addNewClusteringField();
        pmmlClusteringField.setField(mapper.getDerivedFieldName(colName));
        pmmlClusteringField.setCompareFunction(COMPAREFUNCTION.ABS_DIFF);
    }
    // ----------------------------------------------------
    // set clusters
    // The coverage array is only usable if it holds exactly one entry per prototype; this does not
    // change while iterating, so it is evaluated once.
    final boolean hasCoverage = m_clusterCoverage != null && m_clusterCoverage.length == m_prototypes.length;
    int i = 0;
    for (double[] prototype : m_prototypes) {
        ClusterDocument.Cluster pmmlCluster = clusteringModel.addNewCluster();
        pmmlCluster.setName(CLUSTER_NAME_PREFIX + i);
        if (hasCoverage) {
            pmmlCluster.setSize(BigInteger.valueOf(m_clusterCoverage[i]));
        }
        i++;
        ArrayType pmmlArray = pmmlCluster.addNewArray();
        pmmlArray.setN(BigInteger.valueOf(prototype.length));
        pmmlArray.setType(Type.REAL);
        // Space separated coordinates; every value (including the last) is followed by a blank.
        StringBuilder buff = new StringBuilder();
        for (double d : prototype) {
            buff.append(d).append(' ');
        }
        XmlCursor xmlCursor = pmmlArray.newCursor();
        try {
            xmlCursor.setTextValue(buff.toString());
        } finally {
            // release the cursor even if writing the text fails
            xmlCursor.dispose();
        }
    }
    return ClusteringModel.type;
}
use of org.dmg.pmml.PMMLDocument.PMML in project knime-core by knime.
the class PMMLClusterTranslator method initializeFrom.
/**
 * {@inheritDoc}
 * <p>
 * Reads the first clustering model of the document and fills the used columns, the prototypes, the
 * cluster labels, the cluster coverage and the comparison measure of this translator from it.
 *
 * @throws IllegalArgumentException if a clustering field uses a compare function other than
 *             {@code absDiff}, or if the comparison measure is neither euclidean nor squared euclidean
 */
@Override
public void initializeFrom(final PMMLDocument pmmlDoc) {
    PMML pmml = pmmlDoc.getPMML();
    DerivedFieldMapper mapper = new DerivedFieldMapper(pmmlDoc);
    ClusteringModelDocument.ClusteringModel pmmlClusteringModel = pmml.getClusteringModelArray(0);
    // ---------------------------------------------------
    // initialize m_usedColumns from the ClusteringFields; the columns are collected exactly once here
    // so that the prototype width computed below matches the number of collected columns
    for (ClusteringField cf : pmmlClusteringModel.getClusteringFieldArray()) {
        m_usedColumns.add(mapper.getColumnName(cf.getField()));
        if (COMPAREFUNCTION.ABS_DIFF != cf.getCompareFunction()) {
            // plain concatenation instead of toString(): the compare function may be null, in
            // which case the IllegalArgumentException below must still be reached
            LOGGER.error("Comparison Function " + cf.getCompareFunction() + " is not supported!");
            throw new IllegalArgumentException("Only the absolute difference (\"absDiff\") as " + "compare function is supported!");
        }
    }
    // ---------------------------------------------------
    // initialize Clusters
    m_nrOfClusters = pmmlClusteringModel.sizeOfClusterArray();
    m_prototypes = new double[m_nrOfClusters][m_usedColumns.size()];
    m_labels = new String[m_nrOfClusters];
    m_clusterCoverage = new int[m_nrOfClusters];
    for (int i = 0; i < m_nrOfClusters; i++) {
        ClusterDocument.Cluster currentCluster = pmmlClusteringModel.getClusterArray(i);
        m_labels[i] = currentCluster.getName();
        // in KNIME learner: m_labels[i] = "cluster_" + i;
        String[] stringValues = tokenizeArrayContent(readArrayText(currentCluster.getArray()));
        for (int j = 0; j < m_usedColumns.size(); j++) {
            m_prototypes[i][j] = Double.parseDouble(stringValues[j]);
        }
        if (currentCluster.isSetSize()) {
            m_clusterCoverage[i] = currentCluster.getSize().intValue();
        }
    }
    // ---------------------------------------------------
    // check missing value weights: they are not stored, a weight different from one is only reported
    if (pmmlClusteringModel.isSetMissingValueWeights()) {
        String[] weightTokens =
            tokenizeArrayContent(readArrayText(pmmlClusteringModel.getMissingValueWeights().getArray()));
        for (String token : weightTokens) {
            if (token.length() == 0) {
                // empty array content yields a single empty token; nothing to check
                continue;
            }
            boolean isOne;
            try {
                isOne = Double.parseDouble(token) == 1.0;
            } catch (NumberFormatException nfe) {
                // the weights are advisory only, so an unparsable weight is reported like any other
                // unsupported weight instead of aborting the import
                isOne = false;
            }
            if (!isOne) {
                String msg = "Missing Value Weight not equals one" + " is not supported!";
                LOGGER.error(msg);
            }
        }
    }
    // --------------------------------------------
    // initialize Comparison Measure
    ComparisonMeasureDocument.ComparisonMeasure pmmlComparisonMeasure = pmmlClusteringModel.getComparisonMeasure();
    if (pmmlComparisonMeasure.isSetSquaredEuclidean()) {
        m_measure = ComparisonMeasure.squaredEuclidean;
    } else if (pmmlComparisonMeasure.isSetEuclidean()) {
        m_measure = ComparisonMeasure.euclidean;
    } else {
        String measure = pmmlComparisonMeasure.getDomNode().getFirstChild().getNodeName();
        throw new IllegalArgumentException("\"" + ComparisonMeasure.euclidean + "\" and \"" + ComparisonMeasure.squaredEuclidean + "\" are the only supported comparison " + "measures! Found " + measure + ".");
    }
    if (Kind.SIMILARITY == pmmlComparisonMeasure.getKind()) {
        LOGGER.error("A Similarity Kind of Comparison Measure is not " + "supported!");
    }
}

/**
 * Returns the text content of a PMML array element and releases the cursor used to read it.
 *
 * @param array the array element to read
 * @return the text value reported by the cursor
 */
private String readArrayText(final ArrayType array) {
    XmlCursor cursor = array.newCursor();
    try {
        return cursor.getTextValue();
    } finally {
        cursor.dispose();
    }
}

/**
 * Splits the text content of a PMML array into its entries. Content without double quotes is split at
 * whitespace. Content with double quotes is split at a quote followed by a space; the quotes are then
 * removed, escaped quotes are restored and each entry is trimmed.
 *
 * @param rawContent the untrimmed text content of the array
 * @return the entries in document order
 */
private String[] tokenizeArrayContent(final String rawContent) {
    String content = rawContent.trim();
    if (!content.contains(DOUBLE_QUOT)) {
        return content.split("\\s+");
    }
    // protect escaped quotes while splitting; they are restored per entry below
    content = content.replace(BACKSLASH + DOUBLE_QUOT, TAB);
    /* TODO We need to take care of the cases with double quots,
     * e.g ==> <Array n="3" type="string">"Cheval Blanc" "TABTAB"
       "Latour"</Array> */
    String[] tokens = content.split(DOUBLE_QUOT + SPACE);
    for (int j = 0; j < tokens.length; j++) {
        tokens[j] = tokens[j].replace(DOUBLE_QUOT, "").replace(TAB, DOUBLE_QUOT).trim();
    }
    return tokens;
}
use of org.dmg.pmml.PMMLDocument.PMML in project knime-core by knime.
the class PMMLDecisionTreeTranslator method exportTo.
/**
 * {@inheritDoc}
 * <p>
 * Appends a new {@code TreeModel} element named "DecisionTree" to the document, writes the mining schema
 * for {@code spec} into it, transfers the strategies of {@code m_tree} and serializes the tree starting
 * at its root node. The no-true-child strategy is only written for the two strategies handled below.
 */
@Override
public SchemaType exportTo(final PMMLDocument pmmlDoc, final PMMLPortObjectSpec spec) {
    m_nameMapper = new DerivedFieldMapper(pmmlDoc);
    final TreeModelDocument.TreeModel pmmlTree = pmmlDoc.getPMML().addNewTreeModel();
    PMMLMiningSchemaTranslator.writeMiningSchema(spec, pmmlTree);
    pmmlTree.setModelName("DecisionTree");
    pmmlTree.setFunctionName(m_isClassification ? MININGFUNCTION.CLASSIFICATION : MININGFUNCTION.REGRESSION);

    // Split characteristic as reported for the tree below the root.
    final boolean multiSplit = treeIsMultisplit(m_tree.getRootNode());
    pmmlTree.setSplitCharacteristic(multiSplit ? SplitCharacteristic.MULTI_SPLIT : SplitCharacteristic.BINARY_SPLIT);

    // Missing value strategy: a tree without a strategy is exported with NONE.
    PMMLMissingValueStrategy missingValueStrategy = m_tree.getMVStrategy();
    if (missingValueStrategy == null) {
        missingValueStrategy = PMMLMissingValueStrategy.NONE;
    }
    pmmlTree.setMissingValueStrategy(MV_STRATEGY_TO_PMML_MAP.get(missingValueStrategy));

    // No true child strategy: any other value leaves the attribute untouched.
    final PMMLNoTrueChildStrategy noTrueChildStrategy = m_tree.getNTCStrategy();
    final boolean returnLast = PMMLNoTrueChildStrategy.RETURN_LAST_PREDICTION.equals(noTrueChildStrategy);
    if (returnLast) {
        pmmlTree.setNoTrueChildStrategy(NOTRUECHILDSTRATEGY.RETURN_LAST_PREDICTION);
    } else if (PMMLNoTrueChildStrategy.RETURN_NULL_PREDICTION.equals(noTrueChildStrategy)) {
        pmmlTree.setNoTrueChildStrategy(NOTRUECHILDSTRATEGY.RETURN_NULL_PREDICTION);
    }

    // Serialize the tree itself into a freshly added root <Node>.
    addTreeNode(pmmlTree.addNewNode(), m_tree.getRootNode(), new DerivedFieldMapper(pmmlDoc));
    return TreeModel.type;
}
use of org.dmg.pmml.PMMLDocument.PMML in project knime-core by knime.
the class PMMLKNNTranslator method exportTo.
/**
 * {@inheritDoc}
 * <p>
 * Appends a {@code NearestNeighborModel} to the document. The included learning columns become the KNN
 * inputs, and the rows of {@code m_table} are written as an inline table whose elements are named
 * {@code col0}, {@code col1}, ... for the learning columns and {@code target} for the first target field
 * of {@code spec}. A non-negative {@code m_maxRecords} limits the number of rows written.
 */
@Override
public SchemaType exportTo(final PMMLDocument pmmlDoc, final PMMLPortObjectSpec spec) {
    // Column index -> generated inline table element name, kept in the order of m_includes.
    final LinkedHashMap<Integer, String> elementNameByColumn = new LinkedHashMap<Integer, String>();
    final DataTableSpec tableSpec = m_table.getDataTableSpec();
    for (String included : m_includes) {
        elementNameByColumn.put(tableSpec.findColumnIndex(included), "col" + elementNameByColumn.size());
    }

    // Model element and its general attributes.
    final NearestNeighborModel knnModel = pmmlDoc.getPMML().addNewNearestNeighborModel();
    PMMLMiningSchemaTranslator.writeMiningSchema(spec, knnModel);
    knnModel.setAlgorithmName("K-Nearest Neighbors");
    knnModel.setFunctionName(org.dmg.pmml.MININGFUNCTION.CLASSIFICATION);
    knnModel.setNumberOfNeighbors(BigInteger.valueOf(m_numNeighbors));

    // Euclidean is the only comparison measure written.
    knnModel.addNewComparisonMeasure().addNewEuclidean();

    // One KNNInput per learning column.
    final KNNInputs knnInputs = knnModel.addNewKNNInputs();
    for (int columnIndex : elementNameByColumn.keySet()) {
        final KNNInput knnInput = knnInputs.addNewKNNInput();
        knnInput.setField(tableSpec.getColumnSpec(columnIndex).getName());
        knnInput.setCompareFunction(COMPAREFUNCTION.ABS_DIFF);
    }

    // Instance fields map each column name to the element name used inside the inline table.
    final TrainingInstances trainingInstances = knnModel.addNewTrainingInstances();
    final InstanceFields fieldMapping = trainingInstances.addNewInstanceFields();
    for (int columnIndex : elementNameByColumn.keySet()) {
        final InstanceField mappedField = fieldMapping.addNewInstanceField();
        mappedField.setField(tableSpec.getColumnSpec(columnIndex).getName());
        mappedField.setColumn(elementNameByColumn.get(columnIndex));
    }
    final String targetName = spec.getTargetFields().get(0);
    final int targetIdx = tableSpec.findColumnIndex(targetName);
    final InstanceField targetMapping = fieldMapping.addNewInstanceField();
    targetMapping.setField(targetName);
    targetMapping.setColumn("target");

    // The inline table carries the training data itself.
    final InlineTable inlineTable = trainingInstances.addNewInlineTable();
    final Document ownerDocument = inlineTable.getDomNode().getOwnerDocument();
    int rowsSeen = 0;
    for (DataRow row : m_table) {
        // Rows are only counted when a record limit is configured.
        if (m_maxRecords > -1) {
            rowsSeen++;
            if (rowsSeen > m_maxRecords) {
                break;
            }
        }
        final Element rowElement = (Element) inlineTable.addNewRow().getDomNode();
        for (int columnIndex : elementNameByColumn.keySet()) {
            final Element cellElement = ownerDocument.createElementNS(PMMLUtils.getPMMLCurrentVersionNamespace(),
                elementNameByColumn.get(columnIndex));
            cellElement.appendChild(ownerDocument.createTextNode(row.getCell(columnIndex).toString()));
            rowElement.appendChild(cellElement);
        }
        final Element targetElement =
            ownerDocument.createElementNS(PMMLUtils.getPMMLCurrentVersionNamespace(), "target");
        targetElement.appendChild(ownerDocument.createTextNode(row.getCell(targetIdx).toString()));
        rowElement.appendChild(targetElement);
    }
    return NearestNeighborModel.type;
}
Aggregations