Use of org.dmg.pmml.RowDocument.Row in project knime-core by knime.
Class PMMLKNNTranslator, method exportTo:
/**
 * {@inheritDoc}
 * <p>
 * This implementation adds a {@code NearestNeighborModel} (Euclidean comparison measure, classification) to the
 * document. The learning columns ({@code m_includes}) become {@code KNNInput}s and {@code InstanceField}s. The rows
 * of {@code m_table} are written to an {@code InlineTable}: at most {@code m_maxRecords} rows, or all rows if
 * {@code m_maxRecords} is -1 or below. Each learning column value is stored under a synthetic element name
 * ({@code col0}, {@code col1}, ...) and the target value under {@code target}.
 * <p>
 * All inputs are validated before the document is modified, so a failure does not leave a half-written model behind.
 *
 * @param pmmlDoc the PMML document the model is added to
 * @param spec the port spec providing the mining schema; its first target field is used as the class column
 * @return the schema type of the exported model element
 * @throws IllegalStateException if a configured learning column is not present in the training table
 * @throws IllegalArgumentException if {@code spec} declares no target field, or the target column is not present in
 *             the training table
 */
@Override
public SchemaType exportTo(final PMMLDocument pmmlDoc, final PMMLPortObjectSpec spec) {
    final DataTableSpec tSpec = m_table.getDataTableSpec();

    // Column index in the training table -> synthetic XML element name used in the inline table.
    final LinkedHashMap<Integer, String> columnNames = new LinkedHashMap<Integer, String>();
    for (String lc : m_includes) {
        final int idx = tSpec.findColumnIndex(lc);
        if (idx < 0) {
            throw new IllegalStateException("Learning column \"" + lc + "\" not found in the training table");
        }
        // Skip repeated includes: re-putting an existing key would rename that column to "col" + size,
        // a name the next distinct column would also receive.
        if (!columnNames.containsKey(idx)) {
            columnNames.put(idx, "col" + columnNames.size());
        }
    }

    if (spec.getTargetFields().isEmpty()) {
        throw new IllegalArgumentException("The PMML spec does not declare a target field");
    }
    final String targetName = spec.getTargetFields().get(0);
    final int targetIdx = tSpec.findColumnIndex(targetName);
    if (targetIdx < 0) {
        throw new IllegalArgumentException("Target column \"" + targetName + "\" not found in the training table");
    }

    // Create initial XML elements
    PMML pmml = pmmlDoc.getPMML();
    NearestNeighborModel knn = pmml.addNewNearestNeighborModel();
    PMMLMiningSchemaTranslator.writeMiningSchema(spec, knn);
    knn.setAlgorithmName("K-Nearest Neighbors");
    knn.setFunctionName(org.dmg.pmml.MININGFUNCTION.CLASSIFICATION);
    knn.setNumberOfNeighbors(BigInteger.valueOf(m_numNeighbors));
    // Only euclidean distance is supported so far
    ComparisonMeasure cm = knn.addNewComparisonMeasure();
    cm.addNewEuclidean();

    // KNNInputs lists the fields used for determining the distance
    KNNInputs inputs = knn.addNewKNNInputs();
    for (int i : columnNames.keySet()) {
        KNNInput input = inputs.addNewKNNInput();
        input.setField(tSpec.getColumnSpec(i).getName());
        input.setCompareFunction(COMPAREFUNCTION.ABS_DIFF);
    }

    TrainingInstances ti = knn.addNewTrainingInstances();
    // Map each column name to the inline-table element name that holds the column's values
    InstanceFields instanceFields = ti.addNewInstanceFields();
    for (int i : columnNames.keySet()) {
        InstanceField instanceField = instanceFields.addNewInstanceField();
        instanceField.setField(tSpec.getColumnSpec(i).getName());
        instanceField.setColumn(columnNames.get(i));
    }
    InstanceField target = instanceFields.addNewInstanceField();
    target.setField(targetName);
    target.setColumn("target");

    // The inline table holds the actual training data, using the synthetic element names from columnNames
    InlineTable it = ti.addNewInlineTable();
    Document doc = it.getDomNode().getOwnerDocument();
    // Loop-invariant: every cell element is created in the same namespace
    final String namespace = PMMLUtils.getPMMLCurrentVersionNamespace();
    int counter = 0;
    for (DataRow row : m_table) {
        // Stop if we have reached the maximum number of records (a value of -1 or below means no limit)
        if (m_maxRecords > -1 && ++counter > m_maxRecords) {
            break;
        }
        Row inlineRow = it.addNewRow();
        Element rowNode = (Element) inlineRow.getDomNode();
        for (int col : columnNames.keySet()) {
            Element field = doc.createElementNS(namespace, columnNames.get(col));
            field.appendChild(doc.createTextNode(row.getCell(col).toString()));
            rowNode.appendChild(field);
        }
        Element targetField = doc.createElementNS(namespace, "target");
        targetField.appendChild(doc.createTextNode(row.getCell(targetIdx).toString()));
        rowNode.appendChild(targetField);
    }
    return NearestNeighborModel.type;
}
Use of org.dmg.pmml.RowDocument.Row in project knime-core by knime.
Class PMMLMapValuesTranslator, method createDerivedFields:
/**
 * Creates the PMML derived field that represents this value mapping as a {@code MapValues} transformation.
 * <p>
 * If input and output column are the same, the derived field gets a newly created derived name and the original
 * column name as display name. Otherwise, it is named after the output column. Each configured mapping entry becomes
 * one row of the {@code MapValues} inline table. The input value is stored in an element named {@code in} and the
 * mapped value in an element named {@code out}.
 *
 * @return a one-element array holding the created derived field
 */
private DerivedField[] createDerivedFields() {
    DerivedField df = DerivedField.Factory.newInstance();
    df.setExtensionArray(createSummaryExtension());
    final String inColumn = m_config.getInColumn();
    /* The field name must be retrieved before creating a new derived
     * name for this derived field as the map only contains the
     * current mapping. */
    String fieldName = m_mapper.getDerivedFieldName(inColumn);
    if (inColumn.equals(m_config.getOutColumn())) {
        df.setDisplayName(inColumn);
        df.setName(m_mapper.createDerivedFieldName(inColumn));
    } else {
        df.setName(m_config.getOutColumn());
    }
    df.setOptype(m_config.getOpType());
    df.setDataType(m_config.getOutDataType());
    MapValues mapValues = df.addNewMapValues();

    // Dummy element names are used instead of m_config.getInColumn()/getOutColumn() because column names
    // could contain characters that are not allowed in XML element names.
    // NOTE(review): the namespace is hard-coded to PMML 4.0; confirm that readers of these inline tables expect
    // this namespace before aligning it with the document's PMML version.
    final String pmmlNamespace = "http://www.dmg.org/PMML-4_0";
    final QName xmlIn = new QName(pmmlNamespace, "in");
    final QName xmlOut = new QName(pmmlNamespace, "out");

    // the element in the InlineTable representing the output column
    mapValues.setOutputColumn(xmlOut.getLocalPart());
    mapValues.setDataType(m_config.getOutDataType());
    if (!m_config.getDefaultValue().isMissing()) {
        mapValues.setDefaultValue(m_config.getDefaultValue().toString());
    }
    if (!m_config.getMapMissingTo().isMissing()) {
        mapValues.setMapMissingTo(m_config.getMapMissingTo().toString());
    }

    // the mapping of input field <-> element in the InlineTable
    FieldColumnPair fieldColPair = mapValues.addNewFieldColumnPair();
    fieldColPair.setField(fieldName);
    fieldColPair.setColumn(xmlIn.getLocalPart());

    InlineTable table = mapValues.addNewInlineTable();
    for (Entry<DataCell, ? extends DataCell> entry : m_config.getEntries().entrySet()) {
        Row row = table.addNewRow();
        XmlCursor cursor = row.newCursor();
        try {
            // Step inside the freshly created (empty) row so the elements below are inserted as its children
            cursor.toNextToken();
            cursor.insertElementWithText(xmlIn, entry.getKey().toString());
            cursor.insertElementWithText(xmlOut, entry.getValue().toString());
        } finally {
            // Release the cursor's resources even if an insertion fails
            cursor.dispose();
        }
    }
    return new DerivedField[] { df };
}
Use of org.dmg.pmml.RowDocument.Row in project knime-core by knime.
Class PMMLMapValuesTranslator, method createDerivedFields:
/**
 * Creates the PMML derived field that represents this value mapping as a {@code MapValues} transformation.
 * <p>
 * If input and output column are the same, the derived field gets a newly created derived name and the original
 * column name as display name. Otherwise, it is named after the output column. Each configured mapping entry becomes
 * one row of the {@code MapValues} inline table. The input value is stored in an element named {@code in} and the
 * mapped value in an element named {@code out}.
 *
 * @return a one-element array holding the created derived field
 */
private DerivedField[] createDerivedFields() {
    DerivedField df = DerivedField.Factory.newInstance();
    df.setExtensionArray(createSummaryExtension());
    final String inColumn = m_config.getInColumn();
    /* The field name must be retrieved before creating a new derived
     * name for this derived field as the map only contains the
     * current mapping. */
    String fieldName = m_mapper.getDerivedFieldName(inColumn);
    if (inColumn.equals(m_config.getOutColumn())) {
        df.setDisplayName(inColumn);
        df.setName(m_mapper.createDerivedFieldName(inColumn));
    } else {
        df.setName(m_config.getOutColumn());
    }
    df.setOptype(m_config.getOpType());
    df.setDataType(m_config.getOutDataType());
    MapValues mapValues = df.addNewMapValues();

    // Dummy element names are used instead of m_config.getInColumn()/getOutColumn() because column names
    // could contain characters that are not allowed in XML element names.
    // NOTE(review): the namespace is hard-coded to PMML 4.0; confirm that readers of these inline tables expect
    // this namespace before aligning it with the document's PMML version.
    final String pmmlNamespace = "http://www.dmg.org/PMML-4_0";
    final QName xmlIn = new QName(pmmlNamespace, "in");
    final QName xmlOut = new QName(pmmlNamespace, "out");

    // the element in the InlineTable representing the output column
    mapValues.setOutputColumn(xmlOut.getLocalPart());
    mapValues.setDataType(m_config.getOutDataType());
    if (!m_config.getDefaultValue().isMissing()) {
        mapValues.setDefaultValue(m_config.getDefaultValue().toString());
    }
    if (!m_config.getMapMissingTo().isMissing()) {
        mapValues.setMapMissingTo(m_config.getMapMissingTo().toString());
    }

    // the mapping of input field <-> element in the InlineTable
    FieldColumnPair fieldColPair = mapValues.addNewFieldColumnPair();
    fieldColPair.setField(fieldName);
    fieldColPair.setColumn(xmlIn.getLocalPart());

    InlineTable table = mapValues.addNewInlineTable();
    for (Entry<DataCell, ? extends DataCell> entry : m_config.getEntries().entrySet()) {
        Row row = table.addNewRow();
        XmlCursor cursor = row.newCursor();
        try {
            // Step inside the freshly created (empty) row so the elements below are inserted as its children
            cursor.toNextToken();
            cursor.insertElementWithText(xmlIn, entry.getKey().toString());
            cursor.insertElementWithText(xmlOut, entry.getValue().toString());
        } finally {
            // Release the cursor's resources even if an insertion fails
            cursor.dispose();
        }
    }
    return new DerivedField[] { df };
}
Aggregations