Search in sources:

Example 1 with Apply

use of org.dmg.pmml.ApplyDocument.Apply in project knime-core by knime.

the class PMMLMany2OneTranslator method createDerivedField.

/**
 * Builds the derived field for the appended column as a chain of nested PMML {@code if} applies,
 * one per source column.
 * <p>
 * For every source column an {@code if} apply is created (on the derived field for the first column,
 * inside the previous {@code if} apply afterwards). Its condition is an {@code equal} apply that compares
 * the column either with the {@code max}/{@code min} of all source columns (for the Maximum/Minimum
 * include methods) or with the constant {@code "1"} (for any other include method). After the condition,
 * a constant holding the column name is added. The last {@code if} apply additionally receives the
 * constant {@code "missing"}.
 *
 * @return the derived field named after the appended column, typed as categorical string
 */
private DerivedField createDerivedField() {
    final DerivedField result = DerivedField.Factory.newInstance();
    result.setName(m_appendedCol);
    result.setDataType(DATATYPE.STRING);
    result.setOptype(OPTYPE.CATEGORICAL);

    final boolean compareWithExtremum = m_method == IncludeMethod.Maximum || m_method == IncludeMethod.Minimum;
    final String extremumFunction = IncludeMethod.Maximum == m_method ? "max" : "min";

    Apply enclosing = null;
    for (final String column : m_sourceCols) {
        // the first "if" hangs directly below the derived field, every further one below its predecessor
        final Apply conditional = enclosing == null ? result.addNewApply() : enclosing.addNewApply();
        conditional.setFunction("if");

        final Apply condition = conditional.addNewApply();
        condition.setFunction("equal");
        condition.addNewFieldRef().setField(column);
        if (compareWithExtremum) {
            // compare the column against the max/min over all source columns
            final Apply extremum = condition.addNewApply();
            extremum.setFunction(extremumFunction);
            for (final String operand : m_sourceCols) {
                extremum.addNewFieldRef().setField(operand);
            }
        } else {
            // every other include method (e.g. Binary) compares against the constant "1"
            condition.addNewConstant().setStringValue("1");
        }

        conditional.addNewConstant().setStringValue(column);
        enclosing = conditional;
    }
    if (enclosing != null) {
        // fallback value of the last "if" in the chain
        enclosing.addNewConstant().setStringValue("missing");
    }
    return result;
}
Also used : Apply(org.dmg.pmml.ApplyDocument.Apply) DerivedField(org.dmg.pmml.DerivedFieldDocument.DerivedField)

Example 2 with Apply

use of org.dmg.pmml.ApplyDocument.Apply in project knime-core by knime.

the class DataColumnSpecFilterPMMLNodeModel method createPMMLOut.

/**
 * Creates the PMML output port object with the excluded columns removed from the data dictionary.
 * <p>
 * If no PMML input is given, only a newly created spec is wrapped in the returned port object. Otherwise
 * the incoming document is parsed, the mining schema of every model and every derived field of the
 * transformation dictionary are inspected for references to excluded columns (each hit is collected and
 * reported through {@code setWarningMessage}), and the excluded data fields are removed from the data
 * dictionary, whose number of fields is updated accordingly.
 *
 * @param pmmlIn the incoming PMML port object, may be {@code null}
 * @param outSpec the table spec handed on to {@code createPMMLSpec}
 * @param res the filter result used to decide which columns are excluded
 * @return the PMML port object for the output
 * @throws XmlException if the incoming PMML document cannot be parsed
 * @throws IllegalArgumentException if creating the output port object from the filtered document fails;
 *             when no column is included, the exception carries an explanatory message
 */
private PMMLPortObject createPMMLOut(final PMMLPortObject pmmlIn, final DataTableSpec outSpec, final FilterResult res) throws XmlException {
    if (pmmlIn == null) {
        return new PMMLPortObject(createPMMLSpec(null, outSpec, res));
    }
    // The buffer never leaves this method, so the unsynchronized StringBuilder is sufficient.
    final StringBuilder warningBuffer = new StringBuilder();
    PMMLDocument pmmldoc;
    try (LockedSupplier<Document> supplier = pmmlIn.getPMMLValue().getDocumentSupplier()) {
        pmmldoc = PMMLDocument.Factory.parse(supplier.get());
    }
    // Inspect models to check if they use any excluded columns
    List<PMMLModelWrapper> models = PMMLModelWrapper.getModelListFromPMMLDocument(pmmldoc);
    for (PMMLModelWrapper model : models) {
        MiningSchema ms = model.getMiningSchema();
        for (MiningField mf : ms.getMiningFieldList()) {
            if (isExcluded(mf.getName(), res)) {
                if (warningBuffer.length() != 0) {
                    warningBuffer.append("\n");
                }
                warningBuffer.append(model.getModelType().name()).append(" uses excluded column ").append(mf.getName());
            }
        }
    }
    final List<String> warningFields = new ArrayList<>();
    PMML pmml = pmmldoc.getPMML();
    // Now check the transformations if they exist
    if (pmml.getTransformationDictionary() != null) {
        for (DerivedField df : pmml.getTransformationDictionary().getDerivedFieldList()) {
            FieldRef fr = df.getFieldRef();
            if (fr != null && isExcluded(fr.getField(), res)) {
                warningFields.add(fr.getField());
            }
            Aggregate a = df.getAggregate();
            if (a != null && isExcluded(a.getField(), res)) {
                warningFields.add(a.getField());
            }
            Apply ap = df.getApply();
            if (ap != null) {
                // Field references may sit in nested Apply elements (e.g. if(equal(field, ...), ...)),
                // so walk the Apply tree breadth-first instead of looking at the top level only.
                // As before, at most one excluded field is reported per derived field, and top-level
                // references are still checked first.
                // NOTE(review): assumes the generated Apply type offers getApplyList() like getFieldRefList().
                final List<Apply> pending = new ArrayList<>();
                pending.add(ap);
                boolean found = false;
                for (int i = 0; i < pending.size() && !found; i++) {
                    final Apply current = pending.get(i);
                    for (FieldRef fieldRef : current.getFieldRefList()) {
                        if (isExcluded(fieldRef.getField(), res)) {
                            warningFields.add(fieldRef.getField());
                            found = true;
                            break;
                        }
                    }
                    pending.addAll(current.getApplyList());
                }
            }
            Discretize d = df.getDiscretize();
            if (d != null && isExcluded(d.getField(), res)) {
                warningFields.add(d.getField());
            }
            MapValues mv = df.getMapValues();
            if (mv != null) {
                for (FieldColumnPair fcp : mv.getFieldColumnPairList()) {
                    if (isExcluded(fcp.getField(), res)) {
                        warningFields.add(fcp.getField());
                    }
                }
            }
            NormContinuous nc = df.getNormContinuous();
            if (nc != null && isExcluded(nc.getField(), res)) {
                warningFields.add(nc.getField());
            }
            NormDiscrete nd = df.getNormDiscrete();
            if (nd != null && isExcluded(nd.getField(), res)) {
                warningFields.add(nd.getField());
            }
        }
    }
    DataDictionary dict = pmml.getDataDictionary();
    List<DataField> fields = dict.getDataFieldList();
    // Apply filter to spec; iterate backwards so that removing a field does not shift unvisited indices
    int numFields = 0;
    for (int i = fields.size() - 1; i >= 0; i--) {
        if (isExcluded(fields.get(i).getName(), res)) {
            dict.removeDataField(i);
        } else {
            numFields++;
        }
    }
    dict.setNumberOfFields(BigInteger.valueOf(numFields));
    pmml.setDataDictionary(dict);
    pmmldoc.setPMML(pmml);
    // generate warnings and set as warning message
    for (String s : warningFields) {
        if (warningBuffer.length() != 0) {
            warningBuffer.append("\n");
        }
        warningBuffer.append("Transformation dictionary uses excluded column ").append(s);
    }
    if (warningBuffer.length() > 0) {
        setWarningMessage(warningBuffer.toString().trim());
    }
    try {
        return new PMMLPortObject(createPMMLSpec(pmmlIn.getSpec(), outSpec, res), pmmldoc);
    } catch (IllegalArgumentException e) {
        if (res.getIncludes().length == 0) {
            // give the user a hint about the most likely cause, keeping the original exception as cause
            throw new IllegalArgumentException("Excluding all columns produces invalid PMML", e);
        }
        throw e;
    }
}
Also used : MiningField(org.dmg.pmml.MiningFieldDocument.MiningField) NormContinuous(org.dmg.pmml.NormContinuousDocument.NormContinuous) Apply(org.dmg.pmml.ApplyDocument.Apply) ArrayList(java.util.ArrayList) FieldColumnPair(org.dmg.pmml.FieldColumnPairDocument.FieldColumnPair) PMMLDocument(org.dmg.pmml.PMMLDocument) Document(org.w3c.dom.Document) MapValues(org.dmg.pmml.MapValuesDocument.MapValues) Discretize(org.dmg.pmml.DiscretizeDocument.Discretize) FieldRef(org.dmg.pmml.FieldRefDocument.FieldRef) DataDictionary(org.dmg.pmml.DataDictionaryDocument.DataDictionary) PMMLModelWrapper(org.knime.core.node.port.pmml.PMMLModelWrapper) NormDiscrete(org.dmg.pmml.NormDiscreteDocument.NormDiscrete) MiningSchema(org.dmg.pmml.MiningSchemaDocument.MiningSchema) DataField(org.dmg.pmml.DataFieldDocument.DataField) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) PMML(org.dmg.pmml.PMMLDocument.PMML) BigInteger(java.math.BigInteger) PMMLDocument(org.dmg.pmml.PMMLDocument) Aggregate(org.dmg.pmml.AggregateDocument.Aggregate) DerivedField(org.dmg.pmml.DerivedFieldDocument.DerivedField)

Example 3 with Apply

use of org.dmg.pmml.ApplyDocument.Apply in project knime-core by knime.

the class PMMLMany2OneTranslator method createDerivedField.

/**
 * Builds the derived field for the appended column as a chain of nested PMML {@code if} applies,
 * one per source column.
 * <p>
 * For every source column an {@code if} apply is created (on the derived field for the first column,
 * inside the previous {@code if} apply afterwards). Its condition is an {@code equal} apply that compares
 * the column either with the {@code max}/{@code min} of all source columns (for the Maximum/Minimum
 * include methods) or with the constant {@code "1"} (for any other include method). After the condition,
 * a constant holding the column name is added. The last {@code if} apply additionally receives the
 * constant {@code "missing"}.
 *
 * @return the derived field named after the appended column, typed as categorical string
 */
private DerivedField createDerivedField() {
    final DerivedField result = DerivedField.Factory.newInstance();
    result.setName(m_appendedCol);
    result.setDataType(DATATYPE.STRING);
    result.setOptype(OPTYPE.CATEGORICAL);

    final boolean compareWithExtremum = m_method == IncludeMethod.Maximum || m_method == IncludeMethod.Minimum;
    final String extremumFunction = IncludeMethod.Maximum == m_method ? "max" : "min";

    Apply enclosing = null;
    for (final String column : m_sourceCols) {
        // the first "if" hangs directly below the derived field, every further one below its predecessor
        final Apply conditional = enclosing == null ? result.addNewApply() : enclosing.addNewApply();
        conditional.setFunction("if");

        final Apply condition = conditional.addNewApply();
        condition.setFunction("equal");
        condition.addNewFieldRef().setField(column);
        if (compareWithExtremum) {
            // compare the column against the max/min over all source columns
            final Apply extremum = condition.addNewApply();
            extremum.setFunction(extremumFunction);
            for (final String operand : m_sourceCols) {
                extremum.addNewFieldRef().setField(operand);
            }
        } else {
            // every other include method (e.g. Binary) compares against the constant "1"
            condition.addNewConstant().setStringValue("1");
        }

        conditional.addNewConstant().setStringValue(column);
        enclosing = conditional;
    }
    if (enclosing != null) {
        // fallback value of the last "if" in the chain
        enclosing.addNewConstant().setStringValue("missing");
    }
    return result;
}
Also used : Apply(org.dmg.pmml.ApplyDocument.Apply) DerivedField(org.dmg.pmml.DerivedFieldDocument.DerivedField)

Example 4 with Apply

use of org.dmg.pmml.ApplyDocument.Apply in project knime-core by knime.

the class MissingCellHandler method createValueReplacingDerivedField.

/**
 * Creates a derived field that substitutes a fixed value for missing values of this handler's column.
 * <p>
 * The derived field takes its name and display name from the handled column, so it carries the same name
 * as the field it reads from.
 *
 * @param dataType the data type of the field; string and boolean yield a categorical optype, every other
 *            type a continuous one
 * @param value the replacement value for the field
 * @return the derived field
 */
protected DerivedField createValueReplacingDerivedField(final DATATYPE.Enum dataType, final String value) {
    final DerivedField derived = DerivedField.Factory.newInstance();
    final boolean categorical = dataType == org.dmg.pmml.DATATYPE.STRING || dataType == org.dmg.pmml.DATATYPE.BOOLEAN;
    derived.setOptype(categorical ? org.dmg.pmml.OPTYPE.CATEGORICAL : org.dmg.pmml.OPTYPE.CONTINUOUS);

    /*
     * Create the PMML equivalent of: "if fieldVal is missing then x else fieldVal"
     * <Apply function="if">
     *    <Apply function="isMissing">
     *        <FieldRef field="fieldVal"/>
     *    </Apply>
     *    <Constant dataType="___" value="x"/>
     *    <FieldRef field="fieldVal"/>
     * </Apply>
     */
    final Apply conditional = derived.addNewApply();
    conditional.setFunction(IF_FUNCTION_NAME);

    // reference to the handled column, used both in the condition and as the else-value
    final String columnName = m_col.getName();
    final FieldRef columnRef = FieldRef.Factory.newInstance();
    columnRef.setField(columnName);

    // condition: isMissing(column)
    final Apply missingCheck = Apply.Factory.newInstance();
    missingCheck.setFieldRefArray(new FieldRef[] { columnRef });
    missingCheck.setFunction(IS_MISSING_FUNCTION_NAME);
    conditional.setApplyArray(new Apply[] { missingCheck });

    // then-value: the fixed replacement
    final Constant fixedValue = Constant.Factory.newInstance();
    fixedValue.setDataType(dataType);
    fixedValue.setStringValue(value);
    conditional.setConstantArray(new Constant[] { fixedValue });

    // else-value: the original column value
    conditional.setFieldRefArray(new FieldRef[] { columnRef });

    derived.setDataType(dataType);
    derived.setName(columnName);
    derived.setDisplayName(columnName);
    return derived;
}
Also used : FieldRef(org.dmg.pmml.FieldRefDocument.FieldRef) Apply(org.dmg.pmml.ApplyDocument.Apply) Constant(org.dmg.pmml.ConstantDocument.Constant) DerivedField(org.dmg.pmml.DerivedFieldDocument.DerivedField)

Example 5 with Apply

use of org.dmg.pmml.ApplyDocument.Apply in project knime-core by knime.

the class PMMLGeneralRegressionTranslator method exportTo.

/**
 * {@inheritDoc}
 * <p>
 * Adds a general regression model to the given document and fills it from this translator's content:
 * local transformations that split vector columns into one integer field per element, the mining schema,
 * the model attributes, the parameter, factor and covariate lists, and the PP, PCov (only if non-empty)
 * and Param matrices.
 *
 * @return the schema type of the general regression model
 */
@Override
public SchemaType exportTo(final PMMLDocument pmmlDoc, final PMMLPortObjectSpec spec) {
    m_nameMapper = new DerivedFieldMapper(pmmlDoc);
    GeneralRegressionModel regression = pmmlDoc.getPMML().addNewGeneralRegressionModel();
    // currently only consumed by the disabled mining schema extension further down
    final JsonObjectBuilder vectorLengthsJson = Json.createObjectBuilder();
    if (!m_content.getVectorLengths().isEmpty()) {
        LocalTransformations transformations = regression.addNewLocalTransformations();
        for (final Entry<? extends String, ? extends Integer> vectorColumn : m_content.getVectorLengths().entrySet()) {
            final String columnName = vectorColumn.getKey();
            final Integer vectorLength = vectorColumn.getValue();
            DataColumnSpec columnSpec = spec.getDataTableSpec().getColumnSpec(columnName);
            if (columnSpec != null) {
                final DataType type = columnSpec.getType();
                final DataColumnProperties props = columnSpec.getProperties();
                // a vector column is either natively typed or a string column tagged with its "realType"
                final boolean bitVector = type.isCompatible(BitVectorValue.class)
                    || (type.isCompatible(StringValue.class) && props.containsProperty("realType")
                        && "BitVector".equals(props.getProperty("realType")));
                final boolean byteVector = type.isCompatible(ByteVectorValue.class)
                    || (type.isCompatible(StringValue.class) && props.containsProperty("realType")
                        && "ByteVector".equals(props.getProperty("realType")));
                if (!byteVector && !bitVector) {
                    throw new UnsupportedOperationException("Not supported type: " + type + " for column: " + columnSpec);
                }
                // byte vector settings take precedence over bit vector settings
                final String lengthAsString = byteVector ? "3" : "1";
                final int width = byteVector ? 4 : 1;
                final int elementCount = vectorLength.intValue();
                for (int index = 0; index < elementCount; ++index) {
                    // one integer field per vector element, extracted with a substring call
                    final DerivedField element = transformations.addNewDerivedField();
                    element.setOptype(OPTYPE.CONTINUOUS);
                    element.setDataType(DATATYPE.INTEGER);
                    element.setName(columnName + "[" + index + "]");
                    Apply substring = element.addNewApply();
                    substring.setFunction("substring");
                    substring.addNewFieldRef().setField(columnName);
                    Constant start = substring.addNewConstant();
                    start.setDataType(DATATYPE.INTEGER);
                    final long startPosition = bitVector ? vectorLength.longValue() - index : index * width + 1L;
                    start.setStringValue(Long.toString(startPosition));
                    Constant length = substring.addNewConstant();
                    length.setDataType(DATATYPE.INTEGER);
                    length.setStringValue(lengthAsString);
                }
            }
            vectorLengthsJson.add(columnName, vectorLength.intValue());
        }
    }
    // PMMLPortObjectSpecCreator newSpecCreator = new PMMLPortObjectSpecCreator(spec);
    // newSpecCreator.addPreprocColNames(m_content.getVectorLengths().entrySet().stream()
    // .flatMap(
    // e -> IntStream.iterate(0, o -> o + 1).limit(e.getValue()).mapToObj(i -> e.getKey() + "[" + i + "]"))
    // .collect(Collectors.toList()));
    PMMLMiningSchemaTranslator.writeMiningSchema(spec, regression);
    // if (!m_content.getVectorLengths().isEmpty()) {
    // Extension miningExtension = reg.getMiningSchema().addNewExtension();
    // miningExtension.setExtender(EXTENDER);
    // miningExtension.setName(VECTOR_COLUMNS_WITH_LENGTH);
    // miningExtension.setValue(jsonBuilder.build().toString());
    // }

    // model attributes; optional ones are only written when they carry a value
    regression.setModelType(getPMMLRegModelType(m_content.getModelType()));
    regression.setFunctionName(getPMMLMiningFunction(m_content.getFunctionName()));
    final String algorithm = m_content.getAlgorithmName();
    if (algorithm != null && !algorithm.isEmpty()) {
        regression.setAlgorithmName(algorithm);
    }
    final String name = m_content.getModelName();
    if (name != null && !name.isEmpty()) {
        regression.setModelName(name);
    }
    final String referenceCategory = m_content.getTargetReferenceCategory();
    if (referenceCategory != null && !referenceCategory.isEmpty()) {
        regression.setTargetReferenceCategory(referenceCategory);
    }
    if (m_content.getOffsetValue() != null) {
        regression.setOffsetValue(m_content.getOffsetValue());
    }

    // parameter list
    ParameterList parameters = regression.addNewParameterList();
    for (PMMLParameter source : m_content.getParameterList()) {
        Parameter target = parameters.addNewParameter();
        target.setName(source.getName());
        final String label = source.getLabel();
        if (label != null) {
            target.setLabel(m_nameMapper.getDerivedFieldName(label));
        }
    }

    // factor list
    FactorList factors = regression.addNewFactorList();
    for (PMMLPredictor factor : m_content.getFactorList()) {
        Predictor target = factors.addNewPredictor();
        target.setName(m_nameMapper.getDerivedFieldName(factor.getName()));
    }

    // covariate list
    CovariateList covariates = regression.addNewCovariateList();
    for (PMMLPredictor covariate : m_content.getCovariateList()) {
        Predictor target = covariates.addNewPredictor();
        target.setName(m_nameMapper.getDerivedFieldName(covariate.getName()));
    }

    // PPMatrix
    PPMatrix ppMatrix = regression.addNewPPMatrix();
    for (PMMLPPCell source : m_content.getPPMatrix()) {
        PPCell target = ppMatrix.addNewPPCell();
        target.setValue(source.getValue());
        target.setPredictorName(m_nameMapper.getDerivedFieldName(source.getPredictorName()));
        target.setParameterName(source.getParameterName());
        final String category = source.getTargetCategory();
        if (category != null && !category.isEmpty()) {
            target.setTargetCategory(category);
        }
    }

    // PCovMatrix, only written when there is at least one cell
    if (m_content.getPCovMatrix().length > 0) {
        PCovMatrix covMatrix = regression.addNewPCovMatrix();
        for (PMMLPCovCell source : m_content.getPCovMatrix()) {
            PCovCell target = covMatrix.addNewPCovCell();
            target.setPRow(source.getPRow());
            target.setPCol(source.getPCol());
            final String tCol = source.getTCol();
            final String tRow = source.getTRow();
            if (tRow != null || tCol != null) {
                target.setTRow(tRow);
                target.setTCol(tCol);
            }
            target.setValue(source.getValue());
            final String category = source.getTargetCategory();
            if (category != null && !category.isEmpty()) {
                target.setTargetCategory(category);
            }
        }
    }

    // ParamMatrix
    ParamMatrix paramMatrix = regression.addNewParamMatrix();
    for (PMMLPCell source : m_content.getParamMatrix()) {
        PCell target = paramMatrix.addNewPCell();
        final String category = source.getTargetCategory();
        if (category != null) {
            target.setTargetCategory(category);
        }
        target.setParameterName(source.getParameterName());
        target.setBeta(source.getBeta());
        final Integer degreesOfFreedom = source.getDf();
        if (degreesOfFreedom != null) {
            target.setDf(BigInteger.valueOf(degreesOfFreedom));
        }
    }
    return GeneralRegressionModel.type;
}
Also used : Predictor(org.dmg.pmml.PredictorDocument.Predictor) Apply(org.dmg.pmml.ApplyDocument.Apply) Constant(org.dmg.pmml.ConstantDocument.Constant) PPCell(org.dmg.pmml.PPCellDocument.PPCell) ByteVectorValue(org.knime.core.data.vector.bytevector.ByteVectorValue) DerivedFieldMapper(org.knime.core.node.port.pmml.preproc.DerivedFieldMapper) DataColumnSpec(org.knime.core.data.DataColumnSpec) FactorList(org.dmg.pmml.FactorListDocument.FactorList) PPCell(org.dmg.pmml.PPCellDocument.PPCell) PCell(org.dmg.pmml.PCellDocument.PCell) DataType(org.knime.core.data.DataType) JsonObjectBuilder(javax.json.JsonObjectBuilder) DataColumnProperties(org.knime.core.data.DataColumnProperties) ParamMatrix(org.dmg.pmml.ParamMatrixDocument.ParamMatrix) PPMatrix(org.dmg.pmml.PPMatrixDocument.PPMatrix) CovariateList(org.dmg.pmml.CovariateListDocument.CovariateList) PCovMatrix(org.dmg.pmml.PCovMatrixDocument.PCovMatrix) BigInteger(java.math.BigInteger) LocalTransformations(org.dmg.pmml.LocalTransformationsDocument.LocalTransformations) PCovCell(org.dmg.pmml.PCovCellDocument.PCovCell) GeneralRegressionModel(org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel) ParameterList(org.dmg.pmml.ParameterListDocument.ParameterList) Parameter(org.dmg.pmml.ParameterDocument.Parameter) BitVectorValue(org.knime.core.data.vector.bitvector.BitVectorValue) DerivedField(org.dmg.pmml.DerivedFieldDocument.DerivedField)

Aggregations

Apply (org.dmg.pmml.ApplyDocument.Apply)5 DerivedField (org.dmg.pmml.DerivedFieldDocument.DerivedField)5 BigInteger (java.math.BigInteger)2 Constant (org.dmg.pmml.ConstantDocument.Constant)2 FieldRef (org.dmg.pmml.FieldRefDocument.FieldRef)2 ArrayList (java.util.ArrayList)1 JsonObjectBuilder (javax.json.JsonObjectBuilder)1 Aggregate (org.dmg.pmml.AggregateDocument.Aggregate)1 CovariateList (org.dmg.pmml.CovariateListDocument.CovariateList)1 DataDictionary (org.dmg.pmml.DataDictionaryDocument.DataDictionary)1 DataField (org.dmg.pmml.DataFieldDocument.DataField)1 Discretize (org.dmg.pmml.DiscretizeDocument.Discretize)1 FactorList (org.dmg.pmml.FactorListDocument.FactorList)1 FieldColumnPair (org.dmg.pmml.FieldColumnPairDocument.FieldColumnPair)1 GeneralRegressionModel (org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel)1 LocalTransformations (org.dmg.pmml.LocalTransformationsDocument.LocalTransformations)1 MapValues (org.dmg.pmml.MapValuesDocument.MapValues)1 MiningField (org.dmg.pmml.MiningFieldDocument.MiningField)1 MiningSchema (org.dmg.pmml.MiningSchemaDocument.MiningSchema)1 NormContinuous (org.dmg.pmml.NormContinuousDocument.NormContinuous)1