Usage of org.dmg.pmml.DataDictionaryDocument.DataDictionary in the knime-core project by KNIME: class PMMLRuleTranslator, method initDataDictionary.
/**
 * Initializes {@link #m_dataDictionary} from the data dictionary of the given
 * {@code pmmlDoc} document. Each data field name is mapped to an unmodifiable
 * list of its declared values; the whole map is unmodifiable as well.
 *
 * @param pmmlDoc A {@link PMMLDocument}.
 */
private void initDataDictionary(final PMMLDocument pmmlDoc) {
    final DataDictionary dict = pmmlDoc.getPMML().getDataDictionary();
    if (dict == null) {
        // no dictionary present — expose an empty, immutable map
        m_dataDictionary = Collections.emptyMap();
        return;
    }
    // sized *2 to keep the load factor comfortable for the expected entries
    final Map<String, List<String>> fieldsToValues =
        new LinkedHashMap<String, List<String>>(dict.sizeOfDataFieldArray() * 2);
    for (final DataField field : dict.getDataFieldList()) {
        final List<String> values = new ArrayList<String>(field.sizeOfValueArray());
        for (final Value value : field.getValueList()) {
            values.add(value.getValue());
        }
        fieldsToValues.put(field.getName(), Collections.unmodifiableList(values));
    }
    m_dataDictionary = Collections.unmodifiableMap(fieldsToValues);
}
Usage of org.dmg.pmml.DataDictionaryDocument.DataDictionary in the knime-core project by KNIME: class PMMLPortObject, method addGlobalTransformations.
/**
 * Adds global transformations to the PMML document. Only DerivedField
 * elements are supported so far. If no global transformations are set so
 * far the dictionary is set as new transformation dictionary, otherwise
 * all contained transformations are appended to the existing one.
 *
 * @param dictionary the transformation dictionary that contains the
 *      transformations to be added
 * @throws IllegalArgumentException if the dictionary contains
 *      DefineFunction elements or if a derived field name is not unique
 */
public void addGlobalTransformations(final TransformationDictionary dictionary) {
    // add the transformations to the TransformationDictionary
    if (dictionary.getDefineFunctionArray().length > 0) {
        throw new IllegalArgumentException("DefineFunctions are not " + "supported so far. Only derived fields are allowed.");
    }
    TransformationDictionary dict = m_pmmlDoc.getPMML().getTransformationDictionary();
    if (dict == null) {
        // no dictionary yet — install the given one and re-read it so that
        // 'dict' refers to the instance now owned by the document
        m_pmmlDoc.getPMML().setTransformationDictionary(dictionary);
        dict = m_pmmlDoc.getPMML().getTransformationDictionary();
    } else {
        // append the transformations to the existing dictionary
        DerivedField[] existingFields = dict.getDerivedFieldArray();
        DerivedField[] result = appendDerivedFields(existingFields, dictionary.getDerivedFieldArray());
        dict.setDerivedFieldArray(result);
    }
    DerivedField[] df = dict.getDerivedFieldArray();
    // collect the (display) names of all derived fields, enforcing uniqueness
    List<String> colNames = new ArrayList<String>(df.length);
    Set<String> dfNames = new HashSet<String>();
    for (int i = 0; i < df.length; i++) {
        String derivedName = df[i].getName();
        // Set.add returns false on duplicates — single lookup instead of
        // contains() followed by add()
        if (!dfNames.add(derivedName)) {
            throw new IllegalArgumentException("Derived field name \"" + derivedName + "\" is not unique.");
        }
        String displayName = df[i].getDisplayName();
        colNames.add(displayName == null ? derivedName : displayName);
    }
    /* Remove data fields from data dictionary that were created as a
     * derived field. In KNIME the origin of columns is not distinguished
     * and all columns are added to the data dictionary. But in PMML this
     * results in duplicate entries. Those columns should only appear once
     * as derived field in the transformation dictionary or local
     * transformations. */
    DataDictionary dataDict = m_pmmlDoc.getPMML().getDataDictionary();
    DataField[] dataFieldArray = dataDict.getDataFieldArray();
    // single-pass filter by name; the previous List.remove(Object) loop was
    // O(n^2) and relied on DataField#equals rather than the field name
    List<DataField> dataFields = new ArrayList<DataField>(dataFieldArray.length);
    for (DataField dataField : dataFieldArray) {
        if (!dfNames.contains(dataField.getName())) {
            dataFields.add(dataField);
        }
    }
    dataDict.setDataFieldArray(dataFields.toArray(new DataField[0]));
    // update the number of fields
    dataDict.setNumberOfFields(BigInteger.valueOf(dataFields.size()));
    // -------------------------------------------------
    // update field names in the model if applicable
    DerivedFieldMapper dfm = new DerivedFieldMapper(df);
    Map<String, String> derivedFieldMap = dfm.getDerivedFieldMap();
    /* Use XPATH to update field names in the model and move the derived
     * fields to local transformations. */
    // NOTE(review): only the FIRST matching model type is handled by this
    // else-if chain — documents with several model kinds are only partially
    // updated; confirm whether that is intended.
    PMML pmml = m_pmmlDoc.getPMML();
    if (pmml.getTreeModelArray().length > 0) {
        fixAttributeAtPath(pmml, TREE_PATH, FIELD, derivedFieldMap);
    } else if (pmml.getClusteringModelArray().length > 0) {
        fixAttributeAtPath(pmml, CLUSTERING_PATH, FIELD, derivedFieldMap);
    } else if (pmml.getNeuralNetworkArray().length > 0) {
        fixAttributeAtPath(pmml, NN_PATH, FIELD, derivedFieldMap);
    } else if (pmml.getSupportVectorMachineModelArray().length > 0) {
        fixAttributeAtPath(pmml, SVM_PATH, FIELD, derivedFieldMap);
    } else if (pmml.getRegressionModelArray().length > 0) {
        fixAttributeAtPath(pmml, REGRESSION_PATH_1, FIELD, derivedFieldMap);
        fixAttributeAtPath(pmml, REGRESSION_PATH_2, NAME, derivedFieldMap);
    } else if (pmml.getGeneralRegressionModelArray().length > 0) {
        fixAttributeAtPath(pmml, GR_PATH_1, NAME, derivedFieldMap);
        fixAttributeAtPath(pmml, GR_PATH_2, LABEL, derivedFieldMap);
        fixAttributeAtPath(pmml, GR_PATH_3, PREDICTOR_NAME, derivedFieldMap);
    }
    // else do nothing as no model exists yet
    // --------------------------------------------------
    // rebuild the spec so the derived columns show up as preprocessing columns
    PMMLPortObjectSpecCreator creator = new PMMLPortObjectSpecCreator(this, m_spec.getDataTableSpec());
    creator.addPreprocColNames(colNames);
    m_spec = creator.createSpec();
}
Usage of org.dmg.pmml.DataDictionaryDocument.DataDictionary in the knime-core project by KNIME: class PMMLDataDictionaryTranslator, method addColSpecsForDataFields.
/**
 * Creates a {@link DataColumnSpec} for every data field in the document's
 * data dictionary and appends it to {@code colSpecs}. For nominal string
 * fields the declared values become the domain; for numeric fields the
 * bounds are taken from the first interval or, failing that, derived from
 * the declared values. Each processed field name is recorded in
 * {@link #m_dictFields}.
 *
 * @param pmmlDoc the PMML document to analyze
 * @param colSpecs the list to add the data column specs to
 * @throws IllegalArgumentException if intervals are defined for a string
 *      field or a declared value cannot be parsed as a double
 */
private void addColSpecsForDataFields(final PMMLDocument pmmlDoc, final List<DataColumnSpec> colSpecs) {
    DataDictionary dict = pmmlDoc.getPMML().getDataDictionary();
    for (DataField dataField : dict.getDataFieldArray()) {
        String name = dataField.getName();
        DataType dataType = getKNIMEDataType(dataField.getDataType());
        DataColumnSpecCreator specCreator = new DataColumnSpecCreator(name, dataType);
        DataColumnDomain domain = null;
        if (dataType.isCompatible(NominalValue.class)) {
            Value[] valueArray = dataField.getValueArray();
            DataCell[] cells;
            if (DataType.getType(StringCell.class).equals(dataType)) {
                if (dataField.getIntervalArray().length > 0) {
                    throw new IllegalArgumentException("Intervals cannot be defined for Strings.");
                }
                cells = new StringCell[valueArray.length];
                if (valueArray.length > 0) {
                    for (int j = 0; j < cells.length; j++) {
                        cells[j] = new StringCell(valueArray[j].getValue());
                    }
                }
                domain = new DataColumnDomainCreator(cells).createDomain();
            }
        } else if (dataType.isCompatible(DoubleValue.class)) {
            Double leftMargin = null;
            Double rightMargin = null;
            Interval[] intervalArray = dataField.getIntervalArray();
            if (intervalArray != null && intervalArray.length > 0) {
                // only the first interval is considered for the bounds
                Interval interval = dataField.getIntervalArray(0);
                leftMargin = interval.getLeftMargin();
                rightMargin = interval.getRightMargin();
            } else if (dataField.getValueArray() != null && dataField.getValueArray().length > 0) {
                // try to derive the bounds from the values
                Value[] valueArray = dataField.getValueArray();
                List<Double> values = new ArrayList<Double>();
                for (int j = 0; j < valueArray.length; j++) {
                    String value = "";
                    try {
                        value = valueArray[j].getValue();
                        values.add(Double.parseDouble(value));
                    } catch (Exception e) {
                        // preserve the original exception as the cause; the
                        // previous message claimed the domain calculation was
                        // skipped although this method actually throws
                        throw new IllegalArgumentException("Domain calculation failed. " + "Value \"" + value + "\" cannot be cast to double.", e);
                    }
                }
                leftMargin = Collections.min(values);
                rightMargin = Collections.max(values);
            }
            if (leftMargin != null && rightMargin != null) {
                // set the bounds of the domain if available
                DataCell lowerBound = null;
                DataCell upperBound = null;
                if (DataType.getType(IntCell.class).equals(dataType)) {
                    lowerBound = new IntCell(leftMargin.intValue());
                    upperBound = new IntCell(rightMargin.intValue());
                } else if (DataType.getType(DoubleCell.class).equals(dataType)) {
                    lowerBound = new DoubleCell(leftMargin);
                    upperBound = new DoubleCell(rightMargin);
                }
                domain = new DataColumnDomainCreator(lowerBound, upperBound).createDomain();
            } else {
                domain = new DataColumnDomainCreator().createDomain();
            }
        }
        specCreator.setDomain(domain);
        colSpecs.add(specCreator.createSpec());
        m_dictFields.add(name);
    }
}
Usage of org.dmg.pmml.DataDictionaryDocument.DataDictionary in the knime-core project by KNIME: class DataColumnSpecFilterPMMLNodeModel, method createPMMLOut.
/**
 * Creates the outgoing PMML port object: parses the incoming PMML, removes
 * all excluded columns from its data dictionary, and collects warnings for
 * models or transformation-dictionary entries that reference excluded
 * columns.
 *
 * @param pmmlIn the incoming PMML port object, may be {@code null}
 * @param outSpec the filtered output table spec
 * @param res the filter result describing included and excluded columns
 * @return the filtered PMML port object
 * @throws XmlException if the incoming PMML document cannot be parsed
 * @throws IllegalArgumentException if excluding all columns produces
 *      invalid PMML
 */
private PMMLPortObject createPMMLOut(final PMMLPortObject pmmlIn, final DataTableSpec outSpec, final FilterResult res) throws XmlException {
    // local, single-threaded accumulator — StringBuilder avoids the
    // unnecessary synchronization of StringBuffer
    StringBuilder warningBuffer = new StringBuilder();
    if (pmmlIn == null) {
        return new PMMLPortObject(createPMMLSpec(null, outSpec, res));
    } else {
        PMMLDocument pmmldoc;
        try (LockedSupplier<Document> supplier = pmmlIn.getPMMLValue().getDocumentSupplier()) {
            pmmldoc = PMMLDocument.Factory.parse(supplier.get());
        }
        // Inspect models to check if they use any excluded columns
        List<PMMLModelWrapper> models = PMMLModelWrapper.getModelListFromPMMLDocument(pmmldoc);
        for (PMMLModelWrapper model : models) {
            MiningSchema ms = model.getMiningSchema();
            for (MiningField mf : ms.getMiningFieldList()) {
                if (isExcluded(mf.getName(), res)) {
                    if (warningBuffer.length() != 0) {
                        warningBuffer.append("\n");
                    }
                    warningBuffer.append(model.getModelType().name() + " uses excluded column " + mf.getName());
                }
            }
        }
        ArrayList<String> warningFields = new ArrayList<String>();
        PMML pmml = pmmldoc.getPMML();
        // Now check the transformations if they exist
        if (pmml.getTransformationDictionary() != null) {
            for (DerivedField df : pmml.getTransformationDictionary().getDerivedFieldList()) {
                // each derived field holds at most one of the following
                // expression kinds; check every accessor for excluded columns
                FieldRef fr = df.getFieldRef();
                if (fr != null && isExcluded(fr.getField(), res)) {
                    warningFields.add(fr.getField());
                }
                Aggregate a = df.getAggregate();
                if (a != null && isExcluded(a.getField(), res)) {
                    warningFields.add(a.getField());
                }
                Apply ap = df.getApply();
                if (ap != null) {
                    for (FieldRef fieldRef : ap.getFieldRefList()) {
                        if (isExcluded(fieldRef.getField(), res)) {
                            warningFields.add(fieldRef.getField());
                            break;
                        }
                    }
                }
                Discretize d = df.getDiscretize();
                if (d != null && isExcluded(d.getField(), res)) {
                    warningFields.add(d.getField());
                }
                MapValues mv = df.getMapValues();
                if (mv != null) {
                    for (FieldColumnPair fcp : mv.getFieldColumnPairList()) {
                        if (isExcluded(fcp.getField(), res)) {
                            warningFields.add(fcp.getField());
                        }
                    }
                }
                NormContinuous nc = df.getNormContinuous();
                if (nc != null && isExcluded(nc.getField(), res)) {
                    warningFields.add(nc.getField());
                }
                NormDiscrete nd = df.getNormDiscrete();
                if (nd != null && isExcluded(nd.getField(), res)) {
                    warningFields.add(nd.getField());
                }
            }
        }
        DataDictionary dict = pmml.getDataDictionary();
        List<DataField> fields = dict.getDataFieldList();
        // Apply filter to spec; iterate backwards so removal does not
        // shift the indices still to be visited
        int numFields = 0;
        for (int i = fields.size() - 1; i >= 0; i--) {
            if (isExcluded(fields.get(i).getName(), res)) {
                dict.removeDataField(i);
            } else {
                numFields++;
            }
        }
        // BigInteger.valueOf is the idiomatic (and cached) int conversion
        dict.setNumberOfFields(BigInteger.valueOf(numFields));
        pmml.setDataDictionary(dict);
        pmmldoc.setPMML(pmml);
        // generate warnings and set as warning message
        for (String s : warningFields) {
            if (warningBuffer.length() != 0) {
                warningBuffer.append("\n");
            }
            warningBuffer.append("Transformation dictionary uses excluded column " + s);
        }
        if (warningBuffer.length() > 0) {
            setWarningMessage(warningBuffer.toString().trim());
        }
        PMMLPortObject outport = null;
        try {
            outport = new PMMLPortObject(createPMMLSpec(pmmlIn.getSpec(), outSpec, res), pmmldoc);
        } catch (IllegalArgumentException e) {
            if (res.getIncludes().length == 0) {
                throw new IllegalArgumentException("Excluding all columns produces invalid PMML", e);
            } else {
                throw e;
            }
        }
        return outport;
    }
}
Usage of org.dmg.pmml.DataDictionaryDocument.DataDictionary in the knime-core project by KNIME: class PMMLDataDictionaryTranslator, method exportTo.
/**
 * Adds a data dictionary to the PMML document based on the
 * {@link DataTableSpec}. Nominal columns with known possible values export
 * those values; bounded numeric columns export a closed interval.
 *
 * @param pmmlDoc the PMML document to export to
 * @param dts the data table spec
 * @return the schema type of the exported schema if applicable, otherwise
 *      null
 * @see #exportTo(PMMLDocument, PMMLPortObjectSpec)
 */
public SchemaType exportTo(final PMMLDocument pmmlDoc, final DataTableSpec dts) {
    final DataDictionary dataDict = DataDictionary.Factory.newInstance();
    dataDict.setNumberOfFields(BigInteger.valueOf(dts.getNumColumns()));
    for (final DataColumnSpec colSpec : dts) {
        final DataField field = dataDict.addNewDataField();
        field.setName(colSpec.getName());
        final DataType type = colSpec.getType();
        field.setOptype(getOptype(type));
        field.setDataType(getPMMLDataType(type));
        final DataColumnDomain domain = colSpec.getDomain();
        if (type.isCompatible(NominalValue.class) && domain.hasValues()) {
            // enumerate every known possible value of the nominal column
            for (final DataCell possVal : domain.getValues()) {
                field.addNewValue().setValue(possVal.toString());
            }
        } else if (type.isCompatible(DoubleValue.class) && domain.hasBounds()) {
            // export the numeric bounds as a closed interval
            final Interval interval = field.addNewInterval();
            interval.setClosure(Interval.Closure.CLOSED_CLOSED);
            interval.setLeftMargin(((DoubleValue)domain.getLowerBound()).getDoubleValue());
            interval.setRightMargin(((DoubleValue)domain.getUpperBound()).getDoubleValue());
        }
    }
    pmmlDoc.getPMML().setDataDictionary(dataDict);
    // no schematype available yet
    return null;
}
Aggregations