
Example 1 with DataDictionary

Use of org.dmg.pmml.DataDictionaryDocument.DataDictionary in project knime-core by knime.

From the class PMMLRuleTranslator, method initDataDictionary:

/**
 * Inits {@link #m_dataDictionary} based on the {@code pmmlDoc} document.
 *
 * @param pmmlDoc A {@link PMMLDocument}.
 */
private void initDataDictionary(final PMMLDocument pmmlDoc) {
    DataDictionary dd = pmmlDoc.getPMML().getDataDictionary();
    if (dd == null) {
        m_dataDictionary = Collections.emptyMap();
        return;
    }
    Map<String, List<String>> dataDictionary = new LinkedHashMap<String, List<String>>(dd.sizeOfDataFieldArray() * 2);
    for (DataField df : dd.getDataFieldList()) {
        List<String> list = new ArrayList<String>(df.sizeOfValueArray());
        for (Value val : df.getValueList()) {
            list.add(val.getValue());
        }
        dataDictionary.put(df.getName(), Collections.unmodifiableList(list));
    }
    m_dataDictionary = Collections.unmodifiableMap(dataDictionary);
}
Also used : DataField(org.dmg.pmml.DataFieldDocument.DataField) ArrayList(java.util.ArrayList) Value(org.dmg.pmml.ValueDocument.Value) List(java.util.List) ArrayList(java.util.ArrayList) LinkedList(java.util.LinkedList) DataDictionary(org.dmg.pmml.DataDictionaryDocument.DataDictionary) LinkedHashMap(java.util.LinkedHashMap)
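
A minimal, self-contained sketch of the same read pattern outside the translator. The class and method names are hypothetical; every XMLBeans accessor used here already appears in the snippet above.

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.dmg.pmml.DataDictionaryDocument.DataDictionary;
import org.dmg.pmml.DataFieldDocument.DataField;
import org.dmg.pmml.ValueDocument.Value;

final class DataDictionaryValues {

    /** Collects the declared values of every data field, keyed by field name. */
    static Map<String, List<String>> valuesByField(final DataDictionary dd) {
        if (dd == null) {
            return Collections.emptyMap(); // no dictionary present
        }
        Map<String, List<String>> result = new LinkedHashMap<>(dd.sizeOfDataFieldArray() * 2);
        for (DataField df : dd.getDataFieldList()) {
            List<String> values = new ArrayList<>(df.sizeOfValueArray());
            for (Value val : df.getValueList()) {
                values.add(val.getValue()); // the literal value string
            }
            result.put(df.getName(), Collections.unmodifiableList(values));
        }
        return Collections.unmodifiableMap(result);
    }
}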

Example 2 with DataDictionary

Use of org.dmg.pmml.DataDictionaryDocument.DataDictionary in project knime-core by knime.

From the class PMMLPortObject, method addGlobalTransformations:

/**
 * Adds global transformations to the PMML document. Only DerivedField
 * elements are supported so far. If no global transformations are set yet,
 * the dictionary becomes the new transformation dictionary; otherwise all
 * contained transformations are appended to the existing one.
 *
 * @param dictionary the transformation dictionary that contains the
 *      transformations to be added
 */
public void addGlobalTransformations(final TransformationDictionary dictionary) {
    // add the transformations to the TransformationDictionary
    if (dictionary.getDefineFunctionArray().length > 0) {
        throw new IllegalArgumentException("DefineFunctions are not " + "supported so far. Only derived fields are allowed.");
    }
    TransformationDictionary dict = m_pmmlDoc.getPMML().getTransformationDictionary();
    if (dict == null) {
        m_pmmlDoc.getPMML().setTransformationDictionary(dictionary);
        dict = m_pmmlDoc.getPMML().getTransformationDictionary();
    } else {
        // append the transformations to the existing dictionary
        DerivedField[] existingFields = dict.getDerivedFieldArray();
        DerivedField[] result = appendDerivedFields(existingFields, dictionary.getDerivedFieldArray());
        dict.setDerivedFieldArray(result);
    }
    DerivedField[] df = dict.getDerivedFieldArray();
    List<String> colNames = new ArrayList<String>(df.length);
    Set<String> dfNames = new HashSet<String>();
    for (int i = 0; i < df.length; i++) {
        String derivedName = df[i].getName();
        if (dfNames.contains(derivedName)) {
            throw new IllegalArgumentException("Derived field name \"" + derivedName + "\" is not unique.");
        }
        dfNames.add(derivedName);
        String displayName = df[i].getDisplayName();
        colNames.add(displayName == null ? derivedName : displayName);
    }
    /* Remove data fields from the data dictionary that were created as a
     * derived field. In KNIME the origin of columns is not distinguished
     * and all columns are added to the data dictionary. In PMML, however,
     * this results in duplicate entries. Those columns should only appear
     * once, as a derived field in the transformation dictionary or in local
     * transformations. */
    DataDictionary dataDict = m_pmmlDoc.getPMML().getDataDictionary();
    DataField[] dataFieldArray = dataDict.getDataFieldArray();
    List<DataField> dataFields = new ArrayList<DataField>(Arrays.asList(dataFieldArray));
    for (DataField dataField : dataFieldArray) {
        if (dfNames.contains(dataField.getName())) {
            dataFields.remove(dataField);
        }
    }
    dataDict.setDataFieldArray(dataFields.toArray(new DataField[0]));
    // update the number of fields
    dataDict.setNumberOfFields(BigInteger.valueOf(dataFields.size()));
    // -------------------------------------------------
    // update field names in the model if applicable
    DerivedFieldMapper dfm = new DerivedFieldMapper(df);
    Map<String, String> derivedFieldMap = dfm.getDerivedFieldMap();
    /* Use XPath to update field names in the model and move the derived
     * fields to local transformations. */
    PMML pmml = m_pmmlDoc.getPMML();
    if (pmml.getTreeModelArray().length > 0) {
        fixAttributeAtPath(pmml, TREE_PATH, FIELD, derivedFieldMap);
    } else if (pmml.getClusteringModelArray().length > 0) {
        fixAttributeAtPath(pmml, CLUSTERING_PATH, FIELD, derivedFieldMap);
    } else if (pmml.getNeuralNetworkArray().length > 0) {
        fixAttributeAtPath(pmml, NN_PATH, FIELD, derivedFieldMap);
    } else if (pmml.getSupportVectorMachineModelArray().length > 0) {
        fixAttributeAtPath(pmml, SVM_PATH, FIELD, derivedFieldMap);
    } else if (pmml.getRegressionModelArray().length > 0) {
        fixAttributeAtPath(pmml, REGRESSION_PATH_1, FIELD, derivedFieldMap);
        fixAttributeAtPath(pmml, REGRESSION_PATH_2, NAME, derivedFieldMap);
    } else if (pmml.getGeneralRegressionModelArray().length > 0) {
        fixAttributeAtPath(pmml, GR_PATH_1, NAME, derivedFieldMap);
        fixAttributeAtPath(pmml, GR_PATH_2, LABEL, derivedFieldMap);
        fixAttributeAtPath(pmml, GR_PATH_3, PREDICTOR_NAME, derivedFieldMap);
    }
    // else do nothing as no model exists yet
    // --------------------------------------------------
    PMMLPortObjectSpecCreator creator = new PMMLPortObjectSpecCreator(this, m_spec.getDataTableSpec());
    creator.addPreprocColNames(colNames);
    m_spec = creator.createSpec();
}
Also used : TransformationDictionary(org.dmg.pmml.TransformationDictionaryDocument.TransformationDictionary) ArrayList(java.util.ArrayList) DataDictionary(org.dmg.pmml.DataDictionaryDocument.DataDictionary) DerivedFieldMapper(org.knime.core.node.port.pmml.preproc.DerivedFieldMapper) DataField(org.dmg.pmml.DataFieldDocument.DataField) PMML(org.dmg.pmml.PMMLDocument.PMML) DerivedField(org.dmg.pmml.DerivedFieldDocument.DerivedField) HashSet(java.util.HashSet) LinkedHashSet(java.util.LinkedHashSet)
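
The helper appendDerivedFields(...) called above is not shown in this snippet. A plausible minimal implementation is plain array concatenation; this is an assumption, not the actual KNIME helper, and the class name is hypothetical.

import java.util.Arrays;

import org.dmg.pmml.DerivedFieldDocument.DerivedField;

final class DerivedFieldArrays {

    /** Returns a new array containing the existing fields followed by the appended ones. */
    static DerivedField[] appendDerivedFields(final DerivedField[] existing, final DerivedField[] toAppend) {
        DerivedField[] result = Arrays.copyOf(existing, existing.length + toAppend.length);
        System.arraycopy(toAppend, 0, result, existing.length, toAppend.length);
        return result;
    }
}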

Example 3 with DataDictionary

Use of org.dmg.pmml.DataDictionaryDocument.DataDictionary in project knime-core by knime.

From the class PMMLDataDictionaryTranslator, method addColSpecsForDataFields:

/**
 * @param pmmlDoc the PMML document to analyze
 * @param colSpecs the list to add the data column specs to
 */
private void addColSpecsForDataFields(final PMMLDocument pmmlDoc, final List<DataColumnSpec> colSpecs) {
    DataDictionary dict = pmmlDoc.getPMML().getDataDictionary();
    for (DataField dataField : dict.getDataFieldArray()) {
        String name = dataField.getName();
        DataType dataType = getKNIMEDataType(dataField.getDataType());
        DataColumnSpecCreator specCreator = new DataColumnSpecCreator(name, dataType);
        DataColumnDomain domain = null;
        if (dataType.isCompatible(NominalValue.class)) {
            Value[] valueArray = dataField.getValueArray();
            DataCell[] cells;
            if (DataType.getType(StringCell.class).equals(dataType)) {
                if (dataField.getIntervalArray().length > 0) {
                    throw new IllegalArgumentException("Intervals cannot be defined for Strings.");
                }
                cells = new StringCell[valueArray.length];
                if (valueArray.length > 0) {
                    for (int j = 0; j < cells.length; j++) {
                        cells[j] = new StringCell(valueArray[j].getValue());
                    }
                }
                domain = new DataColumnDomainCreator(cells).createDomain();
            }
        } else if (dataType.isCompatible(DoubleValue.class)) {
            Double leftMargin = null;
            Double rightMargin = null;
            Interval[] intervalArray = dataField.getIntervalArray();
            if (intervalArray != null && intervalArray.length > 0) {
                Interval interval = dataField.getIntervalArray(0);
                leftMargin = interval.getLeftMargin();
                rightMargin = interval.getRightMargin();
            } else if (dataField.getValueArray() != null && dataField.getValueArray().length > 0) {
                // try to derive the bounds from the values
                Value[] valueArray = dataField.getValueArray();
                List<Double> values = new ArrayList<Double>();
                for (int j = 0; j < valueArray.length; j++) {
                    String value = "";
                    try {
                        value = valueArray[j].getValue();
                        values.add(Double.parseDouble(value));
                    } catch (Exception e) {
                        throw new IllegalArgumentException("Skipping domain calculation. " + "Value \"" + value + "\" cannot be cast to double.");
                    }
                }
                leftMargin = Collections.min(values);
                rightMargin = Collections.max(values);
            }
            if (leftMargin != null && rightMargin != null) {
                // set the bounds of the domain if available
                DataCell lowerBound = null;
                DataCell upperBound = null;
                if (DataType.getType(IntCell.class).equals(dataType)) {
                    lowerBound = new IntCell(leftMargin.intValue());
                    upperBound = new IntCell(rightMargin.intValue());
                } else if (DataType.getType(DoubleCell.class).equals(dataType)) {
                    lowerBound = new DoubleCell(leftMargin);
                    upperBound = new DoubleCell(rightMargin);
                }
                domain = new DataColumnDomainCreator(lowerBound, upperBound).createDomain();
            } else {
                domain = new DataColumnDomainCreator().createDomain();
            }
        }
        specCreator.setDomain(domain);
        colSpecs.add(specCreator.createSpec());
        m_dictFields.add(name);
    }
}
Also used : DataColumnSpecCreator(org.knime.core.data.DataColumnSpecCreator) DoubleCell(org.knime.core.data.def.DoubleCell) ArrayList(java.util.ArrayList) DataColumnDomainCreator(org.knime.core.data.DataColumnDomainCreator) DataDictionary(org.dmg.pmml.DataDictionaryDocument.DataDictionary) IntCell(org.knime.core.data.def.IntCell) DataColumnDomain(org.knime.core.data.DataColumnDomain) DataField(org.dmg.pmml.DataFieldDocument.DataField) StringCell(org.knime.core.data.def.StringCell) DoubleValue(org.knime.core.data.DoubleValue) NominalValue(org.knime.core.data.NominalValue) BooleanValue(org.knime.core.data.BooleanValue) IntValue(org.knime.core.data.IntValue) Value(org.dmg.pmml.ValueDocument.Value) DoubleValue(org.knime.core.data.DoubleValue) DataType(org.knime.core.data.DataType) DataCell(org.knime.core.data.DataCell) Interval(org.dmg.pmml.IntervalDocument.Interval)
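
For the continuous branch above, the spec construction reduces to a handful of KNIME calls. A standalone sketch of that part (the class and method names are hypothetical; the API calls are exactly the ones used in the example):

import org.knime.core.data.DataColumnDomain;
import org.knime.core.data.DataColumnDomainCreator;
import org.knime.core.data.DataColumnSpec;
import org.knime.core.data.DataColumnSpecCreator;
import org.knime.core.data.DataType;
import org.knime.core.data.def.DoubleCell;

final class BoundedDoubleSpec {

    /** Builds a double column spec whose domain carries the given closed bounds. */
    static DataColumnSpec create(final String name, final double leftMargin, final double rightMargin) {
        DataType dataType = DataType.getType(DoubleCell.class); // same lookup as in the example
        DataColumnSpecCreator specCreator = new DataColumnSpecCreator(name, dataType);
        DataColumnDomain domain = new DataColumnDomainCreator(
            new DoubleCell(leftMargin), new DoubleCell(rightMargin)).createDomain();
        specCreator.setDomain(domain);
        return specCreator.createSpec();
    }
}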

Example 4 with DataDictionary

Use of org.dmg.pmml.DataDictionaryDocument.DataDictionary in project knime-core by knime.

From the class DataColumnSpecFilterPMMLNodeModel, method createPMMLOut:

private PMMLPortObject createPMMLOut(final PMMLPortObject pmmlIn, final DataTableSpec outSpec, final FilterResult res) throws XmlException {
    StringBuffer warningBuffer = new StringBuffer();
    if (pmmlIn == null) {
        return new PMMLPortObject(createPMMLSpec(null, outSpec, res));
    } else {
        PMMLDocument pmmldoc;
        try (LockedSupplier<Document> supplier = pmmlIn.getPMMLValue().getDocumentSupplier()) {
            pmmldoc = PMMLDocument.Factory.parse(supplier.get());
        }
        // Inspect models to check if they use any excluded columns
        List<PMMLModelWrapper> models = PMMLModelWrapper.getModelListFromPMMLDocument(pmmldoc);
        for (PMMLModelWrapper model : models) {
            MiningSchema ms = model.getMiningSchema();
            for (MiningField mf : ms.getMiningFieldList()) {
                if (isExcluded(mf.getName(), res)) {
                    if (warningBuffer.length() != 0) {
                        warningBuffer.append("\n");
                    }
                    warningBuffer.append(model.getModelType().name() + " uses excluded column " + mf.getName());
                }
            }
        }
        ArrayList<String> warningFields = new ArrayList<String>();
        PMML pmml = pmmldoc.getPMML();
        // Now check the transformations if they exist
        if (pmml.getTransformationDictionary() != null) {
            for (DerivedField df : pmml.getTransformationDictionary().getDerivedFieldList()) {
                FieldRef fr = df.getFieldRef();
                if (fr != null && isExcluded(fr.getField(), res)) {
                    warningFields.add(fr.getField());
                }
                Aggregate a = df.getAggregate();
                if (a != null && isExcluded(a.getField(), res)) {
                    warningFields.add(a.getField());
                }
                Apply ap = df.getApply();
                if (ap != null) {
                    for (FieldRef fieldRef : ap.getFieldRefList()) {
                        if (isExcluded(fieldRef.getField(), res)) {
                            warningFields.add(fieldRef.getField());
                            break;
                        }
                    }
                }
                Discretize d = df.getDiscretize();
                if (d != null && isExcluded(d.getField(), res)) {
                    warningFields.add(d.getField());
                }
                MapValues mv = df.getMapValues();
                if (mv != null) {
                    for (FieldColumnPair fcp : mv.getFieldColumnPairList()) {
                        if (isExcluded(fcp.getField(), res)) {
                            warningFields.add(fcp.getField());
                        }
                    }
                }
                NormContinuous nc = df.getNormContinuous();
                if (nc != null && isExcluded(nc.getField(), res)) {
                    warningFields.add(nc.getField());
                }
                NormDiscrete nd = df.getNormDiscrete();
                if (nd != null && isExcluded(nd.getField(), res)) {
                    warningFields.add(nd.getField());
                }
            }
        }
        DataDictionary dict = pmml.getDataDictionary();
        List<DataField> fields = dict.getDataFieldList();
        // Apply filter to spec
        int numFields = 0;
        for (int i = fields.size() - 1; i >= 0; i--) {
            if (isExcluded(fields.get(i).getName(), res)) {
                dict.removeDataField(i);
            } else {
                numFields++;
            }
        }
        dict.setNumberOfFields(new BigInteger(Integer.toString(numFields)));
        pmml.setDataDictionary(dict);
        pmmldoc.setPMML(pmml);
        // generate warnings and set as warning message
        for (String s : warningFields) {
            if (warningBuffer.length() != 0) {
                warningBuffer.append("\n");
            }
            warningBuffer.append("Transformation dictionary uses excluded column " + s);
        }
        if (warningBuffer.length() > 0) {
            setWarningMessage(warningBuffer.toString().trim());
        }
        PMMLPortObject outport = null;
        try {
            outport = new PMMLPortObject(createPMMLSpec(pmmlIn.getSpec(), outSpec, res), pmmldoc);
        } catch (IllegalArgumentException e) {
            if (res.getIncludes().length == 0) {
                throw new IllegalArgumentException("Excluding all columns produces invalid PMML", e);
            } else {
                throw e;
            }
        }
        return outport;
    }
}
Also used : MiningField(org.dmg.pmml.MiningFieldDocument.MiningField) NormContinuous(org.dmg.pmml.NormContinuousDocument.NormContinuous) Apply(org.dmg.pmml.ApplyDocument.Apply) ArrayList(java.util.ArrayList) FieldColumnPair(org.dmg.pmml.FieldColumnPairDocument.FieldColumnPair) PMMLDocument(org.dmg.pmml.PMMLDocument) Document(org.w3c.dom.Document) MapValues(org.dmg.pmml.MapValuesDocument.MapValues) Discretize(org.dmg.pmml.DiscretizeDocument.Discretize) FieldRef(org.dmg.pmml.FieldRefDocument.FieldRef) DataDictionary(org.dmg.pmml.DataDictionaryDocument.DataDictionary) PMMLModelWrapper(org.knime.core.node.port.pmml.PMMLModelWrapper) NormDiscrete(org.dmg.pmml.NormDiscreteDocument.NormDiscrete) MiningSchema(org.dmg.pmml.MiningSchemaDocument.MiningSchema) DataField(org.dmg.pmml.DataFieldDocument.DataField) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) PMML(org.dmg.pmml.PMMLDocument.PMML) BigInteger(java.math.BigInteger) PMMLDocument(org.dmg.pmml.PMMLDocument) Aggregate(org.dmg.pmml.AggregateDocument.Aggregate) DerivedField(org.dmg.pmml.DerivedFieldDocument.DerivedField)
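
The predicate isExcluded(...) used throughout the method is not shown. A minimal sketch of what it plausibly checks, assuming FilterResult also exposes the excluded names via getExcludes() (only getIncludes() appears in the snippet); this is an assumption, not the actual node implementation.

import java.util.Arrays;

import org.knime.core.node.util.filter.NameFilterConfiguration.FilterResult;

final class ExclusionCheck {

    /** True if the given column name is in the filter's exclude list. */
    static boolean isExcluded(final String name, final FilterResult res) {
        // a linear scan is fine here, column lists are small
        return Arrays.asList(res.getExcludes()).contains(name);
    }
}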

Example 5 with DataDictionary

Use of org.dmg.pmml.DataDictionaryDocument.DataDictionary in project knime-core by knime.

From the class PMMLDataDictionaryTranslator, method exportTo:

/**
 * Adds a data dictionary to the PMML document based on the
 * {@link DataTableSpec}.
 *
 * @param pmmlDoc the PMML document to export to
 * @param dts the data table spec
 * @return the schema type of the exported schema if applicable, otherwise
 *         null
 * @see #exportTo(PMMLDocument, PMMLPortObjectSpec)
 */
public SchemaType exportTo(final PMMLDocument pmmlDoc, final DataTableSpec dts) {
    DataDictionary dict = DataDictionary.Factory.newInstance();
    dict.setNumberOfFields(BigInteger.valueOf(dts.getNumColumns()));
    DataField dataField;
    for (DataColumnSpec colSpec : dts) {
        dataField = dict.addNewDataField();
        dataField.setName(colSpec.getName());
        DataType dataType = colSpec.getType();
        dataField.setOptype(getOptype(dataType));
        dataField.setDataType(getPMMLDataType(dataType));
        // Value
        if (colSpec.getType().isCompatible(NominalValue.class) && colSpec.getDomain().hasValues()) {
            for (DataCell possVal : colSpec.getDomain().getValues()) {
                Value value = dataField.addNewValue();
                value.setValue(possVal.toString());
            }
        } else if (colSpec.getType().isCompatible(DoubleValue.class) && colSpec.getDomain().hasBounds()) {
            Interval interval = dataField.addNewInterval();
            interval.setClosure(Interval.Closure.CLOSED_CLOSED);
            interval.setLeftMargin(((DoubleValue) colSpec.getDomain().getLowerBound()).getDoubleValue());
            interval.setRightMargin(((DoubleValue) colSpec.getDomain().getUpperBound()).getDoubleValue());
        }
    }
    pmmlDoc.getPMML().setDataDictionary(dict);
    // no schematype available yet
    return null;
}
Also used : DataColumnSpec(org.knime.core.data.DataColumnSpec) DataField(org.dmg.pmml.DataFieldDocument.DataField) DoubleValue(org.knime.core.data.DoubleValue) NominalValue(org.knime.core.data.NominalValue) NominalValue(org.knime.core.data.NominalValue) BooleanValue(org.knime.core.data.BooleanValue) IntValue(org.knime.core.data.IntValue) Value(org.dmg.pmml.ValueDocument.Value) DoubleValue(org.knime.core.data.DoubleValue) DataType(org.knime.core.data.DataType) DataCell(org.knime.core.data.DataCell) DataDictionary(org.dmg.pmml.DataDictionaryDocument.DataDictionary) Interval(org.dmg.pmml.IntervalDocument.Interval)
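
A minimal usage sketch for exportTo. The translator's no-arg constructor, its import path (assumed to match the other PMML port classes), the XMLBeans calls PMMLDocument.Factory.newInstance() and addNewPMML(), and the DoubleCell.TYPE/StringCell.TYPE constants do not appear in the snippet above and are assumptions. Because the spec carries no domain values or bounds, the branches above write no Value or Interval children, only the two DataField elements.

import org.dmg.pmml.PMMLDocument;
import org.knime.core.data.DataColumnSpecCreator;
import org.knime.core.data.DataTableSpec;
import org.knime.core.data.def.DoubleCell;
import org.knime.core.data.def.StringCell;
import org.knime.core.node.port.pmml.PMMLDataDictionaryTranslator;

final class DictionaryExportSketch {

    public static void main(final String[] args) {
        // Two columns without domain information.
        DataTableSpec spec = new DataTableSpec(
            new DataColumnSpecCreator("sepal length", DoubleCell.TYPE).createSpec(),
            new DataColumnSpecCreator("class", StringCell.TYPE).createSpec());

        PMMLDocument doc = PMMLDocument.Factory.newInstance(); // assumed XMLBeans factory pattern
        doc.addNewPMML();                                      // empty PMML root to attach the dictionary to

        new PMMLDataDictionaryTranslator().exportTo(doc, spec);
        System.out.println(doc); // XmlObject#toString prints the XML
    }
}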

Aggregations

DataDictionary (org.dmg.pmml.DataDictionaryDocument.DataDictionary) 5
DataField (org.dmg.pmml.DataFieldDocument.DataField) 5
ArrayList (java.util.ArrayList) 4
Value (org.dmg.pmml.ValueDocument.Value) 3
DerivedField (org.dmg.pmml.DerivedFieldDocument.DerivedField) 2
Interval (org.dmg.pmml.IntervalDocument.Interval) 2
PMML (org.dmg.pmml.PMMLDocument.PMML) 2
BooleanValue (org.knime.core.data.BooleanValue) 2
DataCell (org.knime.core.data.DataCell) 2
DataType (org.knime.core.data.DataType) 2
DoubleValue (org.knime.core.data.DoubleValue) 2
IntValue (org.knime.core.data.IntValue) 2
NominalValue (org.knime.core.data.NominalValue) 2
BigInteger (java.math.BigInteger) 1
HashSet (java.util.HashSet) 1
LinkedHashMap (java.util.LinkedHashMap) 1
LinkedHashSet (java.util.LinkedHashSet) 1
LinkedList (java.util.LinkedList) 1
List (java.util.List) 1
Aggregate (org.dmg.pmml.AggregateDocument.Aggregate) 1