Search in sources:

Example 1 with GeneralRegressionModel

use of org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel in project knime-core by knime.

Method moveGlobalTransformationsToModel of class PMMLPortObject.

/**
 * Moves the derived fields of the global transformation dictionary into the
 * local transformations of the first model found in the document.
 * <p>
 * Model kinds are probed in this order: tree, clustering, neural network,
 * support vector machine, regression, general regression, rule set. Only the
 * first instance of the first kind present is used. If it has no
 * {@code LocalTransformations} element, one is created. Its existing local
 * derived fields are combined with the global ones via
 * {@code appendDerivedFields}, and the dictionary's derived fields are then
 * cleared.
 * <p>
 * Nothing changes when there is no dictionary, when the dictionary holds no
 * derived fields, or when none of the above model kinds exists yet.
 */
public void moveGlobalTransformationsToModel() {
    final PMML pmml = m_pmmlDoc.getPMML();
    final TransformationDictionary dictionary = pmml.getTransformationDictionary();
    if (dictionary == null) {
        // no global transformations at all
        return;
    }
    final DerivedField[] globalFields = dictionary.getDerivedFieldArray();
    if (globalFields == null || globalFields.length == 0) {
        // nothing to be moved
        return;
    }

    // The generated model classes share no common base type, so every model
    // kind has to be handled explicitly.
    final LocalTransformations target;
    if (pmml.sizeOfTreeModelArray() > 0) {
        final TreeModel tree = pmml.getTreeModelArray(0);
        final LocalTransformations existing = tree.getLocalTransformations();
        target = existing != null ? existing : tree.addNewLocalTransformations();
    } else if (pmml.sizeOfClusteringModelArray() > 0) {
        final ClusteringModel clustering = pmml.getClusteringModelArray(0);
        final LocalTransformations existing = clustering.getLocalTransformations();
        target = existing != null ? existing : clustering.addNewLocalTransformations();
    } else if (pmml.sizeOfNeuralNetworkArray() > 0) {
        final NeuralNetwork network = pmml.getNeuralNetworkArray(0);
        final LocalTransformations existing = network.getLocalTransformations();
        target = existing != null ? existing : network.addNewLocalTransformations();
    } else if (pmml.sizeOfSupportVectorMachineModelArray() > 0) {
        final SupportVectorMachineModel svm = pmml.getSupportVectorMachineModelArray(0);
        final LocalTransformations existing = svm.getLocalTransformations();
        target = existing != null ? existing : svm.addNewLocalTransformations();
    } else if (pmml.sizeOfRegressionModelArray() > 0) {
        final RegressionModel regression = pmml.getRegressionModelArray(0);
        final LocalTransformations existing = regression.getLocalTransformations();
        target = existing != null ? existing : regression.addNewLocalTransformations();
    } else if (pmml.sizeOfGeneralRegressionModelArray() > 0) {
        final GeneralRegressionModel generalRegression = pmml.getGeneralRegressionModelArray(0);
        final LocalTransformations existing = generalRegression.getLocalTransformations();
        target = existing != null ? existing : generalRegression.addNewLocalTransformations();
    } else if (pmml.sizeOfRuleSetModelArray() > 0) {
        final RuleSetModel ruleSet = pmml.getRuleSetModelArray(0);
        final LocalTransformations existing = ruleSet.getLocalTransformations();
        target = existing != null ? existing : ruleSet.addNewLocalTransformations();
    } else {
        target = null;
    }

    if (target == null) {
        // no model exists yet; the global transformations stay where they are
        return;
    }
    target.setDerivedFieldArray(appendDerivedFields(target.getDerivedFieldArray(), globalFields));
    // the fields now belong to the model, so remove them from the dictionary
    dictionary.setDerivedFieldArray(new DerivedField[0]);
}
Also used : TreeModel(org.dmg.pmml.TreeModelDocument.TreeModel) RuleSetModel(org.dmg.pmml.RuleSetModelDocument.RuleSetModel) LocalTransformations(org.dmg.pmml.LocalTransformationsDocument.LocalTransformations) TransformationDictionary(org.dmg.pmml.TransformationDictionaryDocument.TransformationDictionary) GeneralRegressionModel(org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel) PMML(org.dmg.pmml.PMMLDocument.PMML) NeuralNetwork(org.dmg.pmml.NeuralNetworkDocument.NeuralNetwork) SupportVectorMachineModel(org.dmg.pmml.SupportVectorMachineModelDocument.SupportVectorMachineModel) DerivedField(org.dmg.pmml.DerivedFieldDocument.DerivedField) ClusteringModel(org.dmg.pmml.ClusteringModelDocument.ClusteringModel) GeneralRegressionModel(org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel) RegressionModel(org.dmg.pmml.RegressionModelDocument.RegressionModel)

Example 2 with GeneralRegressionModel

use of org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel in project knime-core by knime.

Method getSegmentContent of class PMMLModelWrapper.

/**
 * Wraps the model contained in a segment.
 * <p>
 * The segment is checked for a tree, regression, general regression,
 * clustering, naive Bayes, neural network, rule set and support vector
 * machine model, in that order. The first one found is wrapped; later
 * getters are not called.
 *
 * @param s the segment to inspect
 * @return a wrapper around the first model found in the segment, or
 *         {@code null} if the segment holds none of the supported model types
 */
public static PMMLModelWrapper getSegmentContent(final Segment s) {
    TreeModel tree;
    RegressionModel regression;
    GeneralRegressionModel generalRegression;
    ClusteringModel clustering;
    NaiveBayesModel naiveBayes;
    NeuralNetwork network;
    RuleSetModel ruleSet;
    SupportVectorMachineModel svm;

    final PMMLModelWrapper wrapper;
    if ((tree = s.getTreeModel()) != null) {
        wrapper = new PMMLTreeModelWrapper(tree);
    } else if ((regression = s.getRegressionModel()) != null) {
        wrapper = new PMMLRegressionModelWrapper(regression);
    } else if ((generalRegression = s.getGeneralRegressionModel()) != null) {
        wrapper = new PMMLGeneralRegressionModelWrapper(generalRegression);
    } else if ((clustering = s.getClusteringModel()) != null) {
        wrapper = new PMMLClusteringModelWrapper(clustering);
    } else if ((naiveBayes = s.getNaiveBayesModel()) != null) {
        wrapper = new PMMLNaiveBayesModelWrapper(naiveBayes);
    } else if ((network = s.getNeuralNetwork()) != null) {
        wrapper = new PMMLNeuralNetworkWrapper(network);
    } else if ((ruleSet = s.getRuleSetModel()) != null) {
        wrapper = new PMMLRuleSetModelWrapper(ruleSet);
    } else if ((svm = s.getSupportVectorMachineModel()) != null) {
        wrapper = new PMMLSupportVectorMachineModelWrapper(svm);
    } else {
        // the segment contains none of the supported model types
        wrapper = null;
    }
    return wrapper;
}
Also used : RuleSetModel(org.dmg.pmml.RuleSetModelDocument.RuleSetModel) NaiveBayesModel(org.dmg.pmml.NaiveBayesModelDocument.NaiveBayesModel) NeuralNetwork(org.dmg.pmml.NeuralNetworkDocument.NeuralNetwork) RegressionModel(org.dmg.pmml.RegressionModelDocument.RegressionModel) GeneralRegressionModel(org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel) TreeModel(org.dmg.pmml.TreeModelDocument.TreeModel) GeneralRegressionModel(org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel) SupportVectorMachineModel(org.dmg.pmml.SupportVectorMachineModelDocument.SupportVectorMachineModel) ClusteringModel(org.dmg.pmml.ClusteringModelDocument.ClusteringModel)

Example 3 with GeneralRegressionModel

use of org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel in project knime-core by knime.

Method initializeFrom of class PMMLGeneralRegressionTranslator.

/**
 * Reads the first general regression model of the given document into
 * {@code m_content}.
 * <p>
 * If the document holds several general regression models, a warning is
 * logged and only the first one is read. Parameter names, factor and
 * covariate predictor names, and PP-cell predictor names are translated to
 * column names with a {@link DerivedFieldMapper} built from the document.
 * Parameter/factor/covariate lists and PP/PCov/Param matrices that are absent
 * or empty are stored as empty arrays.
 *
 * @param pmmlDoc the PMML document to read the model from
 * @throws IllegalArgumentException if the document contains no general
 *         regression model, or if the model sets the unsupported
 *         {@code cumulativeLink} attribute
 */
@Override
public void initializeFrom(final PMMLDocument pmmlDoc) {
    m_nameMapper = new DerivedFieldMapper(pmmlDoc);
    final List<GeneralRegressionModel> candidates = pmmlDoc.getPMML().getGeneralRegressionModelList();
    if (candidates.isEmpty()) {
        throw new IllegalArgumentException("No general regression model" + " provided.");
    }
    if (candidates.size() > 1) {
        LOGGER.warn("Multiple general regression models found. " + "Only the first model is considered.");
    }
    final GeneralRegressionModel model = candidates.get(0);

    // model attributes
    m_content.setModelType(getKNIMERegModelType(model.getModelType()));
    m_content.setFunctionName(getKNIMEFunctionName(model.getFunctionName()));
    m_content.setAlgorithmName(model.getAlgorithmName());
    m_content.setModelName(model.getModelName());
    if (model.getCumulativeLink() != null) {
        throw new IllegalArgumentException("The attribute \"cumulativeLink\"" + " is currently not supported.");
    }
    m_content.setTargetReferenceCategory(model.getTargetReferenceCategory());
    if (model.isSetOffsetValue()) {
        m_content.setOffsetValue(model.getOffsetValue());
    }

    // vector column lengths are taken from the model's local derived fields
    if (model.getLocalTransformations() != null && model.getLocalTransformations().getDerivedFieldList() != null) {
        updateVectorLengthsBasedOnDerivedFields(model.getLocalTransformations().getDerivedFieldList());
    }
    // NOTE: reading the vector lengths from a MiningSchema extension
    // (EXTENDER / VECTOR_COLUMNS_WITH_LENGTH) is currently disabled.

    // parameter list
    final ParameterList parameters = model.getParameterList();
    if (parameters == null || parameters.sizeOfParameterArray() == 0) {
        m_content.setParameterList(new PMMLParameter[0]);
    } else {
        final List<Parameter> source = parameters.getParameterList();
        final PMMLParameter[] target = new PMMLParameter[source.size()];
        int next = 0;
        for (final Parameter parameter : source) {
            final String columnName = m_nameMapper.getColumnName(parameter.getName());
            final String label = parameter.getLabel();
            target[next++] = (label != null) ? new PMMLParameter(columnName, label) : new PMMLParameter(columnName);
        }
        m_content.setParameterList(target);
    }

    // factor list
    final FactorList factors = model.getFactorList();
    if (factors == null || factors.sizeOfPredictorArray() == 0) {
        m_content.setFactorList(new PMMLPredictor[0]);
    } else {
        final List<Predictor> source = factors.getPredictorList();
        final PMMLPredictor[] target = new PMMLPredictor[source.size()];
        int next = 0;
        for (final Predictor predictor : source) {
            target[next++] = new PMMLPredictor(m_nameMapper.getColumnName(predictor.getName()));
        }
        m_content.setFactorList(target);
    }

    // covariate list
    final CovariateList covariates = model.getCovariateList();
    if (covariates == null || covariates.sizeOfPredictorArray() == 0) {
        m_content.setCovariateList(new PMMLPredictor[0]);
    } else {
        final List<Predictor> source = covariates.getPredictorList();
        final PMMLPredictor[] target = new PMMLPredictor[source.size()];
        int next = 0;
        for (final Predictor predictor : source) {
            target[next++] = new PMMLPredictor(m_nameMapper.getColumnName(predictor.getName()));
        }
        m_content.setCovariateList(target);
    }

    // PPMatrix
    final PPMatrix ppMatrix = model.getPPMatrix();
    if (ppMatrix == null || ppMatrix.sizeOfPPCellArray() == 0) {
        m_content.setPPMatrix(new PMMLPPCell[0]);
    } else {
        final List<PPCell> source = ppMatrix.getPPCellList();
        final PMMLPPCell[] target = new PMMLPPCell[source.size()];
        int next = 0;
        for (final PPCell cell : source) {
            target[next++] = new PMMLPPCell(cell.getValue(), m_nameMapper.getColumnName(cell.getPredictorName()),
                cell.getParameterName(), cell.getTargetCategory());
        }
        m_content.setPPMatrix(target);
    }

    // PCovMatrix
    final PCovMatrix covMatrix = model.getPCovMatrix();
    if (covMatrix == null || covMatrix.sizeOfPCovCellArray() == 0) {
        m_content.setPCovMatrix(new PMMLPCovCell[0]);
    } else {
        final List<PCovCell> source = covMatrix.getPCovCellList();
        final PMMLPCovCell[] target = new PMMLPCovCell[source.size()];
        int next = 0;
        for (final PCovCell cell : source) {
            target[next++] = new PMMLPCovCell(cell.getPRow(), cell.getPCol(), cell.getTRow(), cell.getTCol(),
                cell.getValue(), cell.getTargetCategory());
        }
        m_content.setPCovMatrix(target);
    }

    // ParamMatrix
    final ParamMatrix paramMatrix = model.getParamMatrix();
    if (paramMatrix == null || paramMatrix.sizeOfPCellArray() == 0) {
        m_content.setParamMatrix(new PMMLPCell[0]);
    } else {
        final List<PCell> source = paramMatrix.getPCellList();
        final PMMLPCell[] target = new PMMLPCell[source.size()];
        int next = 0;
        for (final PCell cell : source) {
            final double beta = cell.getBeta();
            final BigInteger df = cell.getDf();
            // the degrees of freedom are optional in the model
            target[next++] = (df == null)
                ? new PMMLPCell(cell.getParameterName(), beta, cell.getTargetCategory())
                : new PMMLPCell(cell.getParameterName(), beta, df.intValue(), cell.getTargetCategory());
        }
        m_content.setParamMatrix(target);
    }
}
Also used : Predictor(org.dmg.pmml.PredictorDocument.Predictor) PPCell(org.dmg.pmml.PPCellDocument.PPCell) DerivedFieldMapper(org.knime.core.node.port.pmml.preproc.DerivedFieldMapper) FunctionName(org.knime.base.node.mine.regression.pmmlgreg.PMMLGeneralRegressionContent.FunctionName) FactorList(org.dmg.pmml.FactorListDocument.FactorList) PPCell(org.dmg.pmml.PPCellDocument.PPCell) PCell(org.dmg.pmml.PCellDocument.PCell) ParamMatrix(org.dmg.pmml.ParamMatrixDocument.ParamMatrix) PPMatrix(org.dmg.pmml.PPMatrixDocument.PPMatrix) CovariateList(org.dmg.pmml.CovariateListDocument.CovariateList) PCovMatrix(org.dmg.pmml.PCovMatrixDocument.PCovMatrix) PCovCell(org.dmg.pmml.PCovCellDocument.PCovCell) GeneralRegressionModel(org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel) ParameterList(org.dmg.pmml.ParameterListDocument.ParameterList) Parameter(org.dmg.pmml.ParameterDocument.Parameter) BigInteger(java.math.BigInteger)

Example 4 with GeneralRegressionModel

use of org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel in project knime-core by knime.

Method getFirstMiningSchema of class PMMLUtils.

/**
 * Returns the mining schema of the first model of the given type.
 *
 * @param pmmlDoc the PMML document to extract the mining schema from
 * @param type the schema type of the model, e.g. {@code TreeModel.type}
 * @return the mining schema of the first model of the given type, or
 *         {@code null} if the document contains no model of that type or the
 *         type is not one of the model kinds handled here
 */
public static MiningSchema getFirstMiningSchema(final PMMLDocument pmmlDoc, final SchemaType type) {
    if (!getNumberOfModels(pmmlDoc).containsKey(PMMLModelType.getType(type))) {
        return null;
    }
    final PMML pmml = pmmlDoc.getPMML();
    // The generated PMML model classes share no common base class, so every
    // model kind needs its own accessor to reach the mining schema.
    if (AssociationModel.type.equals(type)) {
        return pmml.getAssociationModelArray(0).getMiningSchema();
    }
    if (ClusteringModel.type.equals(type)) {
        return pmml.getClusteringModelArray(0).getMiningSchema();
    }
    if (GeneralRegressionModel.type.equals(type)) {
        return pmml.getGeneralRegressionModelArray(0).getMiningSchema();
    }
    if (MiningModel.type.equals(type)) {
        return pmml.getMiningModelArray(0).getMiningSchema();
    }
    if (NaiveBayesModel.type.equals(type)) {
        return pmml.getNaiveBayesModelArray(0).getMiningSchema();
    }
    if (NeuralNetwork.type.equals(type)) {
        return pmml.getNeuralNetworkArray(0).getMiningSchema();
    }
    if (RegressionModel.type.equals(type)) {
        return pmml.getRegressionModelArray(0).getMiningSchema();
    }
    if (RuleSetModel.type.equals(type)) {
        return pmml.getRuleSetModelArray(0).getMiningSchema();
    }
    if (SequenceModel.type.equals(type)) {
        return pmml.getSequenceModelArray(0).getMiningSchema();
    }
    if (SupportVectorMachineModel.type.equals(type)) {
        return pmml.getSupportVectorMachineModelArray(0).getMiningSchema();
    }
    if (TextModel.type.equals(type)) {
        return pmml.getTextModelArray(0).getMiningSchema();
    }
    if (TimeSeriesModel.type.equals(type)) {
        return pmml.getTimeSeriesModelArray(0).getMiningSchema();
    }
    if (TreeModel.type.equals(type)) {
        return pmml.getTreeModelArray(0).getMiningSchema();
    }
    // not a model kind handled here
    return null;
}
Also used : RuleSetModel(org.dmg.pmml.RuleSetModelDocument.RuleSetModel) SequenceModel(org.dmg.pmml.SequenceModelDocument.SequenceModel) TextModel(org.dmg.pmml.TextModelDocument.TextModel) NaiveBayesModel(org.dmg.pmml.NaiveBayesModelDocument.NaiveBayesModel) TimeSeriesModel(org.dmg.pmml.TimeSeriesModelDocument.TimeSeriesModel) NeuralNetwork(org.dmg.pmml.NeuralNetworkDocument.NeuralNetwork) RegressionModel(org.dmg.pmml.RegressionModelDocument.RegressionModel) GeneralRegressionModel(org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel) TreeModel(org.dmg.pmml.TreeModelDocument.TreeModel) MiningModel(org.dmg.pmml.MiningModelDocument.MiningModel) GeneralRegressionModel(org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel) PMML(org.dmg.pmml.PMMLDocument.PMML) SupportVectorMachineModel(org.dmg.pmml.SupportVectorMachineModelDocument.SupportVectorMachineModel) AssociationModel(org.dmg.pmml.AssociationModelDocument.AssociationModel) ClusteringModel(org.dmg.pmml.ClusteringModelDocument.ClusteringModel)

Example 5 with GeneralRegressionModel

use of org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel in project knime-core by knime.

Method exportTo of class PMMLGeneralRegressionTranslator.

/**
 * Writes {@code m_content} as a new general regression model into the given
 * document.
 * <p>
 * For every vector column with a recorded length that exists in the spec's
 * table spec, one derived field per vector element is added to the model's
 * local transformations. Then the following are written:
 * <ul>
 * <li>the mining schema</li>
 * <li>the model attributes</li>
 * <li>the parameter, factor and covariate lists</li>
 * <li>the PP matrix</li>
 * <li>the parameter covariance matrix, only if it has cells</li>
 * <li>the parameter matrix</li>
 * </ul>
 *
 * @param pmmlDoc the document to add the model to
 * @param spec the port object spec, used to look up the vector column types
 *        and to write the mining schema
 * @return {@code GeneralRegressionModel.type}
 * @throws UnsupportedOperationException if a vector column is neither a bit
 *         vector nor a byte vector (nor a string column tagged as one)
 */
@Override
public SchemaType exportTo(final PMMLDocument pmmlDoc, final PMMLPortObjectSpec spec) {
    m_nameMapper = new DerivedFieldMapper(pmmlDoc);
    GeneralRegressionModel reg = pmmlDoc.getPMML().addNewGeneralRegressionModel();
    if (!m_content.getVectorLengths().isEmpty()) {
        LocalTransformations localTransformations = reg.addNewLocalTransformations();
        for (final Entry<? extends String, ? extends Integer> entry : m_content.getVectorLengths().entrySet()) {
            DataColumnSpec columnSpec = spec.getDataTableSpec().getColumnSpec(entry.getKey());
            if (columnSpec != null) {
                addVectorElementDerivedFields(localTransformations, columnSpec, entry.getKey(),
                    entry.getValue().intValue());
            }
        }
    }
    // NOTE: vector lengths are only persisted through the derived fields above;
    // storing them additionally in a MiningSchema extension
    // (EXTENDER / VECTOR_COLUMNS_WITH_LENGTH) is currently disabled.
    PMMLMiningSchemaTranslator.writeMiningSchema(spec, reg);
    reg.setModelType(getPMMLRegModelType(m_content.getModelType()));
    reg.setFunctionName(getPMMLMiningFunction(m_content.getFunctionName()));
    String algorithmName = m_content.getAlgorithmName();
    if (algorithmName != null && !algorithmName.isEmpty()) {
        reg.setAlgorithmName(algorithmName);
    }
    String modelName = m_content.getModelName();
    if (modelName != null && !modelName.isEmpty()) {
        reg.setModelName(modelName);
    }
    String targetReferenceCategory = m_content.getTargetReferenceCategory();
    if (targetReferenceCategory != null && !targetReferenceCategory.isEmpty()) {
        reg.setTargetReferenceCategory(targetReferenceCategory);
    }
    if (m_content.getOffsetValue() != null) {
        reg.setOffsetValue(m_content.getOffsetValue());
    }
    // add parameter list
    ParameterList paramList = reg.addNewParameterList();
    for (PMMLParameter p : m_content.getParameterList()) {
        Parameter param = paramList.addNewParameter();
        param.setName(p.getName());
        String label = p.getLabel();
        if (label != null) {
            param.setLabel(m_nameMapper.getDerivedFieldName(label));
        }
    }
    // add factor list
    FactorList factorList = reg.addNewFactorList();
    for (PMMLPredictor p : m_content.getFactorList()) {
        Predictor predictor = factorList.addNewPredictor();
        predictor.setName(m_nameMapper.getDerivedFieldName(p.getName()));
    }
    // add covariate list
    CovariateList covariateList = reg.addNewCovariateList();
    for (PMMLPredictor p : m_content.getCovariateList()) {
        Predictor predictor = covariateList.addNewPredictor();
        predictor.setName(m_nameMapper.getDerivedFieldName(p.getName()));
    }
    // add PPMatrix
    PPMatrix ppMatrix = reg.addNewPPMatrix();
    for (PMMLPPCell p : m_content.getPPMatrix()) {
        PPCell cell = ppMatrix.addNewPPCell();
        cell.setValue(p.getValue());
        cell.setPredictorName(m_nameMapper.getDerivedFieldName(p.getPredictorName()));
        cell.setParameterName(p.getParameterName());
        String targetCategory = p.getTargetCategory();
        if (targetCategory != null && !targetCategory.isEmpty()) {
            cell.setTargetCategory(targetCategory);
        }
    }
    // add CovMatrix (optional, only written when there are cells)
    if (m_content.getPCovMatrix().length > 0) {
        PCovMatrix pCovMatrix = reg.addNewPCovMatrix();
        for (PMMLPCovCell p : m_content.getPCovMatrix()) {
            PCovCell covCell = pCovMatrix.addNewPCovCell();
            covCell.setPRow(p.getPRow());
            covCell.setPCol(p.getPCol());
            // Set each optional target attribute on its own; passing null to
            // the generated setter would write a bogus value for the missing one.
            String tRow = p.getTRow();
            if (tRow != null) {
                covCell.setTRow(tRow);
            }
            String tCol = p.getTCol();
            if (tCol != null) {
                covCell.setTCol(tCol);
            }
            covCell.setValue(p.getValue());
            String targetCategory = p.getTargetCategory();
            if (targetCategory != null && !targetCategory.isEmpty()) {
                covCell.setTargetCategory(targetCategory);
            }
        }
    }
    // add ParamMatrix
    ParamMatrix paramMatrix = reg.addNewParamMatrix();
    for (PMMLPCell p : m_content.getParamMatrix()) {
        PCell pCell = paramMatrix.addNewPCell();
        String targetCategory = p.getTargetCategory();
        if (targetCategory != null) {
            pCell.setTargetCategory(targetCategory);
        }
        pCell.setParameterName(p.getParameterName());
        pCell.setBeta(p.getBeta());
        Integer df = p.getDf();
        if (df != null) {
            pCell.setDf(BigInteger.valueOf(df));
        }
    }
    return GeneralRegressionModel.type;
}

/**
 * Adds one derived field per element of a vector column to the given local
 * transformations.
 * <p>
 * Each field is named {@code columnName + "[" + i + "]"}, is continuous with
 * integer data type, and is computed as
 * {@code substring(columnName, from, length)}:
 * <ul>
 * <li>Bit vectors: {@code from = vectorLength - i}, {@code length = 1}.
 * Presumably the string form lists the highest bit first; confirm against
 * the bit-vector string format.</li>
 * <li>Byte vectors: {@code from = 4 * i + 1}, {@code length = 3}. Presumably
 * each byte is rendered as 3 characters plus a separator; confirm.</li>
 * </ul>
 * A column counts as a bit/byte vector if its type is compatible with
 * {@code BitVectorValue}/{@code ByteVectorValue}, or if it is a string column
 * whose {@code realType} property says {@code "BitVector"}/{@code "ByteVector"}.
 *
 * @param localTransformations the element to add the derived fields to
 * @param columnSpec the spec of the vector column
 * @param columnName the name of the vector column, used in the field refs
 * @param vectorLength the number of vector elements to create fields for
 * @throws UnsupportedOperationException if the column is neither a bit nor a
 *         byte vector
 */
private static void addVectorElementDerivedFields(final LocalTransformations localTransformations,
    final DataColumnSpec columnSpec, final String columnName, final int vectorLength) {
    final DataType type = columnSpec.getType();
    final DataColumnProperties props = columnSpec.getProperties();
    final boolean bitVector = type.isCompatible(BitVectorValue.class) || (type.isCompatible(StringValue.class)
        && props.containsProperty("realType") && "BitVector".equals(props.getProperty("realType")));
    final boolean byteVector = type.isCompatible(ByteVectorValue.class) || (type.isCompatible(StringValue.class)
        && props.containsProperty("realType") && "ByteVector".equals(props.getProperty("realType")));
    final String lengthAsString;
    final int width;
    if (byteVector) {
        lengthAsString = "3";
        width = 4;
    } else if (bitVector) {
        lengthAsString = "1";
        width = 1;
    } else {
        throw new UnsupportedOperationException("Not supported type: " + type + " for column: " + columnSpec);
    }
    for (int i = 0; i < vectorLength; ++i) {
        final DerivedField derivedField = localTransformations.addNewDerivedField();
        derivedField.setOptype(OPTYPE.CONTINUOUS);
        derivedField.setDataType(DATATYPE.INTEGER);
        derivedField.setName(columnName + "[" + i + "]");
        final Apply apply = derivedField.addNewApply();
        apply.setFunction("substring");
        apply.addNewFieldRef().setField(columnName);
        final Constant from = apply.addNewConstant();
        from.setDataType(DATATYPE.INTEGER);
        // long arithmetic so that large vectors cannot overflow the offset
        from.setStringValue(bitVector ? Long.toString((long)vectorLength - i) : Long.toString((long)i * width + 1L));
        final Constant length = apply.addNewConstant();
        length.setDataType(DATATYPE.INTEGER);
        length.setStringValue(lengthAsString);
    }
}
Also used : Predictor(org.dmg.pmml.PredictorDocument.Predictor) Apply(org.dmg.pmml.ApplyDocument.Apply) Constant(org.dmg.pmml.ConstantDocument.Constant) PPCell(org.dmg.pmml.PPCellDocument.PPCell) ByteVectorValue(org.knime.core.data.vector.bytevector.ByteVectorValue) DerivedFieldMapper(org.knime.core.node.port.pmml.preproc.DerivedFieldMapper) DataColumnSpec(org.knime.core.data.DataColumnSpec) FactorList(org.dmg.pmml.FactorListDocument.FactorList) PPCell(org.dmg.pmml.PPCellDocument.PPCell) PCell(org.dmg.pmml.PCellDocument.PCell) DataType(org.knime.core.data.DataType) JsonObjectBuilder(javax.json.JsonObjectBuilder) DataColumnProperties(org.knime.core.data.DataColumnProperties) ParamMatrix(org.dmg.pmml.ParamMatrixDocument.ParamMatrix) PPMatrix(org.dmg.pmml.PPMatrixDocument.PPMatrix) CovariateList(org.dmg.pmml.CovariateListDocument.CovariateList) PCovMatrix(org.dmg.pmml.PCovMatrixDocument.PCovMatrix) BigInteger(java.math.BigInteger) LocalTransformations(org.dmg.pmml.LocalTransformationsDocument.LocalTransformations) PCovCell(org.dmg.pmml.PCovCellDocument.PCovCell) GeneralRegressionModel(org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel) ParameterList(org.dmg.pmml.ParameterListDocument.ParameterList) Parameter(org.dmg.pmml.ParameterDocument.Parameter) BitVectorValue(org.knime.core.data.vector.bitvector.BitVectorValue) DerivedField(org.dmg.pmml.DerivedFieldDocument.DerivedField)

Aggregations

GeneralRegressionModel (org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel)7 RegressionModel (org.dmg.pmml.RegressionModelDocument.RegressionModel)5 TreeModel (org.dmg.pmml.TreeModelDocument.TreeModel)5 ClusteringModel (org.dmg.pmml.ClusteringModelDocument.ClusteringModel)4 NaiveBayesModel (org.dmg.pmml.NaiveBayesModelDocument.NaiveBayesModel)4 NeuralNetwork (org.dmg.pmml.NeuralNetworkDocument.NeuralNetwork)4 RuleSetModel (org.dmg.pmml.RuleSetModelDocument.RuleSetModel)4 SupportVectorMachineModel (org.dmg.pmml.SupportVectorMachineModelDocument.SupportVectorMachineModel)4 AssociationModel (org.dmg.pmml.AssociationModelDocument.AssociationModel)3 LocalTransformations (org.dmg.pmml.LocalTransformationsDocument.LocalTransformations)3 PMML (org.dmg.pmml.PMMLDocument.PMML)3 SequenceModel (org.dmg.pmml.SequenceModelDocument.SequenceModel)3 TextModel (org.dmg.pmml.TextModelDocument.TextModel)3 BigInteger (java.math.BigInteger)2 CovariateList (org.dmg.pmml.CovariateListDocument.CovariateList)2 DerivedField (org.dmg.pmml.DerivedFieldDocument.DerivedField)2 FactorList (org.dmg.pmml.FactorListDocument.FactorList)2 MiningModel (org.dmg.pmml.MiningModelDocument.MiningModel)2 PCell (org.dmg.pmml.PCellDocument.PCell)2 PCovCell (org.dmg.pmml.PCovCellDocument.PCovCell)2