Search in sources :

Example 6 with DerivedFieldMapper

use of org.knime.core.node.port.pmml.preproc.DerivedFieldMapper in project knime-core by knime.

The class PMMLPortObject, method addGlobalTransformations.

/**
 * Adds global transformations to the PMML document. Only DerivedField
 * elements are supported so far. If no global transformations are set so
 * far the dictionary is set as new transformation dictionary, otherwise
 * all contained transformations are appended to the existing one.
 *
 * <p>Afterwards, data fields whose names coincide with a derived field name
 * are removed from the data dictionary. Field references of the first
 * supported model found in the document are rewritten with the derived
 * field map of a {@link DerivedFieldMapper}. The display names (or names,
 * if no display name is set) of all derived fields are added as
 * preprocessing column names of the port object spec.
 *
 * <p>Both validations (no DefineFunctions, unique derived field names) are
 * performed before the document is modified, so a rejected dictionary
 * leaves the PMML document unchanged.
 *
 * @param dictionary the transformation dictionary that contains the
 *      transformations to be added
 * @throws IllegalArgumentException if {@code dictionary} contains
 *      DefineFunction elements or if the resulting set of derived fields
 *      contains a duplicate name
 */
public void addGlobalTransformations(final TransformationDictionary dictionary) {
    // add the transformations to the TransformationDictionary
    if (dictionary.getDefineFunctionArray().length > 0) {
        throw new IllegalArgumentException("DefineFunctions are not " + "supported so far. Only derived fields are allowed.");
    }
    PMML pmml = m_pmmlDoc.getPMML();
    TransformationDictionary dict = pmml.getTransformationDictionary();
    // Determine the derived fields the dictionary will contain and validate
    // them BEFORE touching the document, so that a duplicate name does not
    // leave a partially updated transformation dictionary behind.
    DerivedField[] mergedFields;
    if (dict == null) {
        mergedFields = dictionary.getDerivedFieldArray();
    } else {
        mergedFields = appendDerivedFields(dict.getDerivedFieldArray(), dictionary.getDerivedFieldArray());
    }
    List<String> colNames = new ArrayList<String>(mergedFields.length);
    Set<String> dfNames = new HashSet<String>();
    for (DerivedField field : mergedFields) {
        String derivedName = field.getName();
        // Set.add returns false if the name was already present
        if (!dfNames.add(derivedName)) {
            throw new IllegalArgumentException("Derived field name \"" + derivedName + "\" is not unique.");
        }
        String displayName = field.getDisplayName();
        colNames.add(displayName == null ? derivedName : displayName);
    }
    // validation passed - now apply the transformations to the document
    if (dict == null) {
        pmml.setTransformationDictionary(dictionary);
        dict = pmml.getTransformationDictionary();
    } else {
        dict.setDerivedFieldArray(mergedFields);
    }
    DerivedField[] df = dict.getDerivedFieldArray();
    /* Remove data fields from the data dictionary that were created as a
     * derived field. In KNIME the origin of columns is not distinguished
     * and all columns are added to the data dictionary. But in PMML this
     * results in duplicate entries. Those columns should only appear once
     * as derived field in the transformation dictionary or local
     * transformations. */
    DataDictionary dataDict = pmml.getDataDictionary();
    DataField[] dataFieldArray = dataDict.getDataFieldArray();
    List<DataField> dataFields = new ArrayList<DataField>(dataFieldArray.length);
    for (DataField dataField : dataFieldArray) {
        if (!dfNames.contains(dataField.getName())) {
            dataFields.add(dataField);
        }
    }
    dataDict.setDataFieldArray(dataFields.toArray(new DataField[0]));
    // update the number of fields
    dataDict.setNumberOfFields(BigInteger.valueOf(dataFields.size()));
    // -------------------------------------------------
    // update field names in the model if applicable
    DerivedFieldMapper dfm = new DerivedFieldMapper(df);
    Map<String, String> derivedFieldMap = dfm.getDerivedFieldMap();
    /* Use XPATH to update field names in the model and move the derived
     * fields to local transformations. Only the first model type found
     * (in the order below) is processed. */
    if (pmml.getTreeModelArray().length > 0) {
        fixAttributeAtPath(pmml, TREE_PATH, FIELD, derivedFieldMap);
    } else if (pmml.getClusteringModelArray().length > 0) {
        fixAttributeAtPath(pmml, CLUSTERING_PATH, FIELD, derivedFieldMap);
    } else if (pmml.getNeuralNetworkArray().length > 0) {
        fixAttributeAtPath(pmml, NN_PATH, FIELD, derivedFieldMap);
    } else if (pmml.getSupportVectorMachineModelArray().length > 0) {
        fixAttributeAtPath(pmml, SVM_PATH, FIELD, derivedFieldMap);
    } else if (pmml.getRegressionModelArray().length > 0) {
        fixAttributeAtPath(pmml, REGRESSION_PATH_1, FIELD, derivedFieldMap);
        fixAttributeAtPath(pmml, REGRESSION_PATH_2, NAME, derivedFieldMap);
    } else if (pmml.getGeneralRegressionModelArray().length > 0) {
        fixAttributeAtPath(pmml, GR_PATH_1, NAME, derivedFieldMap);
        fixAttributeAtPath(pmml, GR_PATH_2, LABEL, derivedFieldMap);
        fixAttributeAtPath(pmml, GR_PATH_3, PREDICTOR_NAME, derivedFieldMap);
    }
    // else do nothing as no model exists yet
    // --------------------------------------------------
    PMMLPortObjectSpecCreator creator = new PMMLPortObjectSpecCreator(this, m_spec.getDataTableSpec());
    creator.addPreprocColNames(colNames);
    m_spec = creator.createSpec();
}
Also used : TransformationDictionary(org.dmg.pmml.TransformationDictionaryDocument.TransformationDictionary) ArrayList(java.util.ArrayList) DataDictionary(org.dmg.pmml.DataDictionaryDocument.DataDictionary) DerivedFieldMapper(org.knime.core.node.port.pmml.preproc.DerivedFieldMapper) DataField(org.dmg.pmml.DataFieldDocument.DataField) PMML(org.dmg.pmml.PMMLDocument.PMML) DerivedField(org.dmg.pmml.DerivedFieldDocument.DerivedField) HashSet(java.util.HashSet) LinkedHashSet(java.util.LinkedHashSet)

Example 7 with DerivedFieldMapper

use of org.knime.core.node.port.pmml.preproc.DerivedFieldMapper in project knime-core by knime.

The class PMMLGeneralRegressionTranslator, method initializeFrom.

/**
 * {@inheritDoc}
 *
 * <p>Reads the first general regression model of the document into
 * {@code m_content}; if several are present, a warning is logged and the
 * others are ignored. Parameter, factor, covariate and PPMatrix predictor
 * names are translated with {@code m_nameMapper.getColumnName}, where the
 * mapper is built from the given document.
 *
 * @throws IllegalArgumentException if the document contains no general
 *      regression model or if the model sets the unsupported
 *      {@code cumulativeLink} attribute
 */
@Override
public void initializeFrom(final PMMLDocument pmmlDoc) {
    m_nameMapper = new DerivedFieldMapper(pmmlDoc);
    List<GeneralRegressionModel> models = pmmlDoc.getPMML().getGeneralRegressionModelList();
    if (models.isEmpty()) {
        throw new IllegalArgumentException("No general regression model" + " provided.");
    } else if (models.size() > 1) {
        LOGGER.warn("Multiple general regression models found. " + "Only the first model is considered.");
    }
    GeneralRegressionModel reg = models.get(0);
    // read the content type
    PMMLGeneralRegressionContent.ModelType modelType = getKNIMERegModelType(reg.getModelType());
    m_content.setModelType(modelType);
    // read the function name
    FunctionName functionName = getKNIMEFunctionName(reg.getFunctionName());
    m_content.setFunctionName(functionName);
    m_content.setAlgorithmName(reg.getAlgorithmName());
    m_content.setModelName(reg.getModelName());
    if (reg.getCumulativeLink() != null) {
        throw new IllegalArgumentException("The attribute \"cumulativeLink\"" + " is currently not supported.");
    }
    m_content.setTargetReferenceCategory(reg.getTargetReferenceCategory());
    if (reg.isSetOffsetValue()) {
        m_content.setOffsetValue(reg.getOffsetValue());
    }
    if (reg.getLocalTransformations() != null && reg.getLocalTransformations().getDerivedFieldList() != null) {
        updateVectorLengthsBasedOnDerivedFields(reg.getLocalTransformations().getDerivedFieldList());
    }
    // read the parameter list
    ParameterList pmmlParamList = reg.getParameterList();
    if (pmmlParamList != null && pmmlParamList.sizeOfParameterArray() > 0) {
        List<Parameter> pmmlParam = pmmlParamList.getParameterList();
        PMMLParameter[] paramList = new PMMLParameter[pmmlParam.size()];
        for (int i = 0; i < pmmlParam.size(); i++) {
            String name = m_nameMapper.getColumnName(pmmlParam.get(i).getName());
            String label = pmmlParam.get(i).getLabel();
            if (label == null) {
                paramList[i] = new PMMLParameter(name);
            } else {
                paramList[i] = new PMMLParameter(name, label);
            }
        }
        m_content.setParameterList(paramList);
    } else {
        m_content.setParameterList(new PMMLParameter[0]);
    }
    // read the factor list
    FactorList pmmlFactorList = reg.getFactorList();
    if (pmmlFactorList != null && pmmlFactorList.sizeOfPredictorArray() > 0) {
        m_content.setFactorList(toPMMLPredictors(pmmlFactorList.getPredictorList()));
    } else {
        m_content.setFactorList(new PMMLPredictor[0]);
    }
    // read covariate list
    CovariateList covariateList = reg.getCovariateList();
    if (covariateList != null && covariateList.sizeOfPredictorArray() > 0) {
        m_content.setCovariateList(toPMMLPredictors(covariateList.getPredictorList()));
    } else {
        m_content.setCovariateList(new PMMLPredictor[0]);
    }
    // read PPMatrix
    PPMatrix ppMatrix = reg.getPPMatrix();
    if (ppMatrix != null && ppMatrix.sizeOfPPCellArray() > 0) {
        List<PPCell> pmmlCellArray = ppMatrix.getPPCellList();
        PMMLPPCell[] cells = new PMMLPPCell[pmmlCellArray.size()];
        for (int i = 0; i < pmmlCellArray.size(); i++) {
            PPCell ppCell = pmmlCellArray.get(i);
            // only the predictor name refers to a field; the parameter name is kept as is
            cells[i] = new PMMLPPCell(ppCell.getValue(), m_nameMapper.getColumnName(ppCell.getPredictorName()), ppCell.getParameterName(), ppCell.getTargetCategory());
        }
        m_content.setPPMatrix(cells);
    } else {
        m_content.setPPMatrix(new PMMLPPCell[0]);
    }
    // read CovMatrix
    PCovMatrix pCovMatrix = reg.getPCovMatrix();
    if (pCovMatrix != null && pCovMatrix.sizeOfPCovCellArray() > 0) {
        List<PCovCell> pCovCellArray = pCovMatrix.getPCovCellList();
        PMMLPCovCell[] covCells = new PMMLPCovCell[pCovCellArray.size()];
        for (int i = 0; i < pCovCellArray.size(); i++) {
            PCovCell c = pCovCellArray.get(i);
            covCells[i] = new PMMLPCovCell(c.getPRow(), c.getPCol(), c.getTRow(), c.getTCol(), c.getValue(), c.getTargetCategory());
        }
        m_content.setPCovMatrix(covCells);
    } else {
        m_content.setPCovMatrix(new PMMLPCovCell[0]);
    }
    // read ParamMatrix
    ParamMatrix paramMatrix = reg.getParamMatrix();
    if (paramMatrix != null && paramMatrix.sizeOfPCellArray() > 0) {
        List<PCell> pCellArray = paramMatrix.getPCellList();
        PMMLPCell[] cells = new PMMLPCell[pCellArray.size()];
        for (int i = 0; i < pCellArray.size(); i++) {
            PCell p = pCellArray.get(i);
            double beta = p.getBeta();
            BigInteger df = p.getDf();
            // the degrees of freedom are optional; use the constructor without them if absent
            if (df != null) {
                cells[i] = new PMMLPCell(p.getParameterName(), beta, df.intValue(), p.getTargetCategory());
            } else {
                cells[i] = new PMMLPCell(p.getParameterName(), beta, p.getTargetCategory());
            }
        }
        m_content.setParamMatrix(cells);
    } else {
        m_content.setParamMatrix(new PMMLPCell[0]);
    }
}

/**
 * Converts PMML predictors (from a FactorList or CovariateList) into KNIME
 * predictors, translating each predictor name with
 * {@code m_nameMapper.getColumnName}.
 *
 * @param pmmlPredictors the predictors read from the PMML model
 * @return the converted predictors, in the same order
 */
private PMMLPredictor[] toPMMLPredictors(final List<Predictor> pmmlPredictors) {
    PMMLPredictor[] predictors = new PMMLPredictor[pmmlPredictors.size()];
    for (int i = 0; i < predictors.length; i++) {
        predictors[i] = new PMMLPredictor(m_nameMapper.getColumnName(pmmlPredictors.get(i).getName()));
    }
    return predictors;
}
Also used : Predictor(org.dmg.pmml.PredictorDocument.Predictor) PPCell(org.dmg.pmml.PPCellDocument.PPCell) DerivedFieldMapper(org.knime.core.node.port.pmml.preproc.DerivedFieldMapper) FunctionName(org.knime.base.node.mine.regression.pmmlgreg.PMMLGeneralRegressionContent.FunctionName) FactorList(org.dmg.pmml.FactorListDocument.FactorList) PPCell(org.dmg.pmml.PPCellDocument.PPCell) PCell(org.dmg.pmml.PCellDocument.PCell) ParamMatrix(org.dmg.pmml.ParamMatrixDocument.ParamMatrix) PPMatrix(org.dmg.pmml.PPMatrixDocument.PPMatrix) CovariateList(org.dmg.pmml.CovariateListDocument.CovariateList) PCovMatrix(org.dmg.pmml.PCovMatrixDocument.PCovMatrix) PCovCell(org.dmg.pmml.PCovCellDocument.PCovCell) GeneralRegressionModel(org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel) ParameterList(org.dmg.pmml.ParameterListDocument.ParameterList) Parameter(org.dmg.pmml.ParameterDocument.Parameter) BigInteger(java.math.BigInteger)

Example 8 with DerivedFieldMapper

use of org.knime.core.node.port.pmml.preproc.DerivedFieldMapper in project knime-core by knime.

The class PMMLClusterTranslator, method exportTo.

/**
 * {@inheritDoc}
 *
 * <p>Appends a ClusteringModel named "k-means" (function CLUSTERING,
 * model class CENTER_BASED) to the document. It writes:
 * <ul>
 * <li>a DISTANCE comparison measure: squaredEuclidean if {@code m_measure}
 * is {@code squaredEuclidean}, euclidean otherwise;</li>
 * <li>one absDiff ClusteringField per used column, named via
 * {@code DerivedFieldMapper.getDerivedFieldName};</li>
 * <li>one Cluster per prototype, with the prototype coordinates as a REAL
 * array.</li>
 * </ul>
 * Cluster sizes are only written when {@code m_clusterCoverage} holds
 * exactly one value per prototype.
 */
@Override
public SchemaType exportTo(final PMMLDocument pmmlDoc, final PMMLPortObjectSpec spec) {
    DerivedFieldMapper mapper = new DerivedFieldMapper(pmmlDoc);
    PMML pmml = pmmlDoc.getPMML();
    ClusteringModelDocument.ClusteringModel clusteringModel = pmml.addNewClusteringModel();
    PMMLMiningSchemaTranslator.writeMiningSchema(spec, clusteringModel);
    // ---------------------------------------------------
    // set clustering model attributes
    clusteringModel.setModelName("k-means");
    clusteringModel.setFunctionName(MININGFUNCTION.CLUSTERING);
    clusteringModel.setModelClass(ModelClass.CENTER_BASED);
    clusteringModel.setNumberOfClusters(BigInteger.valueOf(m_nrOfClusters));
    // ---------------------------------------------------
    // set comparison measure
    ComparisonMeasureDocument.ComparisonMeasure pmmlComparisonMeasure = clusteringModel.addNewComparisonMeasure();
    pmmlComparisonMeasure.setKind(Kind.DISTANCE);
    if (ComparisonMeasure.squaredEuclidean.equals(m_measure)) {
        pmmlComparisonMeasure.addNewSquaredEuclidean();
    } else {
        pmmlComparisonMeasure.addNewEuclidean();
    }
    // set clustering fields
    for (String colName : m_usedColumns) {
        ClusteringFieldDocument.ClusteringField pmmlClusteringField = clusteringModel.addNewClusteringField();
        pmmlClusteringField.setField(mapper.getDerivedFieldName(colName));
        pmmlClusteringField.setCompareFunction(COMPAREFUNCTION.ABS_DIFF);
    }
    // ----------------------------------------------------
    // set clusters
    // the coverage check does not change inside the loop, so evaluate it once
    boolean writeSizes = m_clusterCoverage != null && m_clusterCoverage.length == m_prototypes.length;
    for (int i = 0; i < m_prototypes.length; i++) {
        double[] prototype = m_prototypes[i];
        ClusterDocument.Cluster pmmlCluster = clusteringModel.addNewCluster();
        pmmlCluster.setName(CLUSTER_NAME_PREFIX + i);
        if (writeSizes) {
            pmmlCluster.setSize(BigInteger.valueOf(m_clusterCoverage[i]));
        }
        ArrayType pmmlArray = pmmlCluster.addNewArray();
        pmmlArray.setN(BigInteger.valueOf(prototype.length));
        pmmlArray.setType(Type.REAL);
        // space-separated values, each followed by a single space
        StringBuilder buff = new StringBuilder();
        for (double d : prototype) {
            buff.append(d).append(' ');
        }
        XmlCursor xmlCursor = pmmlArray.newCursor();
        try {
            xmlCursor.setTextValue(buff.toString());
        } finally {
            // release the cursor even if setting the text fails
            xmlCursor.dispose();
        }
    }
    return ClusteringModel.type;
}
Also used : ClusteringModel(org.dmg.pmml.ClusteringModelDocument.ClusteringModel) ClusteringModelDocument(org.dmg.pmml.ClusteringModelDocument) ComparisonMeasureDocument(org.dmg.pmml.ComparisonMeasureDocument) ClusterDocument(org.dmg.pmml.ClusterDocument) XmlCursor(org.apache.xmlbeans.XmlCursor) ArrayType(org.dmg.pmml.ArrayType) DerivedFieldMapper(org.knime.core.node.port.pmml.preproc.DerivedFieldMapper) ClusteringFieldDocument(org.dmg.pmml.ClusteringFieldDocument) PMML(org.dmg.pmml.PMMLDocument.PMML) ClusteringField(org.dmg.pmml.ClusteringFieldDocument.ClusteringField)

Example 9 with DerivedFieldMapper

use of org.knime.core.node.port.pmml.preproc.DerivedFieldMapper in project knime-core by knime.

The class PMMLClusterTranslator, method initializeFrom.

/**
 * {@inheritDoc}
 *
 * <p>Reads the first ClusteringModel of the document and populates the
 * translator from it:
 * <ul>
 * <li>the used columns, from the ClusteringFields translated with a
 * {@link DerivedFieldMapper};</li>
 * <li>the cluster labels, prototypes and optional cluster sizes;</li>
 * <li>the comparison measure.</li>
 * </ul>
 * MissingValueWeights other than 1.0 and similarity-kind comparison
 * measures are not supported; they are only reported via the logger.
 *
 * @throws IllegalArgumentException if a ClusteringField uses a compare
 *      function other than absDiff, or if the comparison measure is
 *      neither euclidean nor squaredEuclidean
 */
@Override
public void initializeFrom(final PMMLDocument pmmlDoc) {
    PMML pmml = pmmlDoc.getPMML();
    DerivedFieldMapper mapper = new DerivedFieldMapper(pmmlDoc);
    ClusteringModelDocument.ClusteringModel pmmlClusteringModel = pmml.getClusteringModelArray(0);
    // initialize m_usedColumns from the ClusteringFields. This is done
    // exactly once; the prototype width below depends on it.
    for (ClusteringField cf : pmmlClusteringModel.getClusteringFieldArray()) {
        m_usedColumns.add(mapper.getColumnName(cf.getField()));
        if (COMPAREFUNCTION.ABS_DIFF != cf.getCompareFunction()) {
            LOGGER.error("Comparison Function " + cf.getCompareFunction().toString() + " is not supported!");
            throw new IllegalArgumentException("Only the absolute difference (\"absDiff\") as " + "compare function is supported!");
        }
    }
    // ---------------------------------------------------
    // initialize Clusters
    m_nrOfClusters = pmmlClusteringModel.sizeOfClusterArray();
    m_prototypes = new double[m_nrOfClusters][m_usedColumns.size()];
    m_labels = new String[m_nrOfClusters];
    m_clusterCoverage = new int[m_nrOfClusters];
    for (int i = 0; i < m_nrOfClusters; i++) {
        ClusterDocument.Cluster currentCluster = pmmlClusteringModel.getClusterArray(i);
        m_labels[i] = currentCluster.getName();
        // in KNIME learner: m_labels[i] = "cluster_" + i;
        String[] stringValues = splitArrayContent(currentCluster.getArray());
        for (int j = 0; j < m_usedColumns.size(); j++) {
            m_prototypes[i][j] = Double.parseDouble(stringValues[j]);
        }
        if (currentCluster.isSetSize()) {
            m_clusterCoverage[i] = currentCluster.getSize().intValue();
        }
    }
    if (pmmlClusteringModel.isSetMissingValueWeights()) {
        String[] weightValues = splitArrayContent(pmmlClusteringModel.getMissingValueWeights().getArray());
        // Only weights of exactly 1.0 are supported. Check every entry,
        // regardless of whether the array content was quoted or not.
        for (String weight : weightValues) {
            if (weight.isEmpty()) {
                // empty array content splits into a single empty token
                continue;
            }
            if (Double.parseDouble(weight) != 1.0) {
                String msg = "Missing Value Weight not equals one" + " is not supported!";
                LOGGER.error(msg);
            }
        }
    }
    // --------------------------------------------
    // initialize Comparison Measure
    ComparisonMeasureDocument.ComparisonMeasure pmmlComparisonMeasure = pmmlClusteringModel.getComparisonMeasure();
    if (pmmlComparisonMeasure.isSetSquaredEuclidean()) {
        m_measure = ComparisonMeasure.squaredEuclidean;
    } else if (pmmlComparisonMeasure.isSetEuclidean()) {
        m_measure = ComparisonMeasure.euclidean;
    } else {
        String measure = pmmlComparisonMeasure.getDomNode().getFirstChild().getNodeName();
        throw new IllegalArgumentException("\"" + ComparisonMeasure.euclidean + "\" and \"" + ComparisonMeasure.squaredEuclidean + "\" are the only supported comparison " + "measures! Found " + measure + ".");
    }
    if (Kind.SIMILARITY == pmmlComparisonMeasure.getKind()) {
        LOGGER.error("A Similarity Kind of Comparison Measure is not " + "supported!");
    }
}

/**
 * Splits the trimmed text content of a PMML Array element into its entries.
 *
 * <p>If the content contains double quotes, it is split at a closing quote
 * followed by a space. Backslash-escaped quotes are preserved, and the
 * surrounding quotes and whitespace are removed from each entry. Otherwise
 * the content is split on runs of whitespace.
 *
 * @param array the PMML array element
 * @return the array entries as strings
 */
private String[] splitArrayContent(final ArrayType array) {
    // NOTE(review): the cursor returned by newCursor() is never disposed here.
    String content = array.newCursor().getTextValue().trim();
    if (!content.contains(DOUBLE_QUOT)) {
        return content.split("\\s+");
    }
    // protect escaped quotes while splitting, restore them afterwards
    content = content.replace(BACKSLASH + DOUBLE_QUOT, TAB);
    /* TODO We need to take care of the cases with double quots,
     * e.g ==> <Array n="3" type="string">"Cheval  Blanc" "TABTAB"
     "Latour"</Array> */
    String[] values = content.split(DOUBLE_QUOT + SPACE);
    for (int j = 0; j < values.length; j++) {
        values[j] = values[j].replace(DOUBLE_QUOT, "").replace(TAB, DOUBLE_QUOT).trim();
    }
    return values;
}
Also used : ClusteringModel(org.dmg.pmml.ClusteringModelDocument.ClusteringModel) ClusteringModelDocument(org.dmg.pmml.ClusteringModelDocument) ClusteringField(org.dmg.pmml.ClusteringFieldDocument.ClusteringField) ComparisonMeasureDocument(org.dmg.pmml.ComparisonMeasureDocument) ClusterDocument(org.dmg.pmml.ClusterDocument) ArrayType(org.dmg.pmml.ArrayType) DerivedFieldMapper(org.knime.core.node.port.pmml.preproc.DerivedFieldMapper) PMML(org.dmg.pmml.PMMLDocument.PMML)

Example 10 with DerivedFieldMapper

use of org.knime.core.node.port.pmml.preproc.DerivedFieldMapper in project knime-core by knime.

The class PMMLDecisionTreeTranslator, method exportTo.

/**
 * {@inheritDoc}
 *
 * <p>Appends a TreeModel named "DecisionTree" to the document and sets
 * its mining function:
 * <ul>
 * <li>CLASSIFICATION if {@code m_isClassification}, REGRESSION
 * otherwise;</li>
 * <li>a multi- or binary-split characteristic, depending on the tree;</li>
 * <li>the tree's missing value strategy (NONE if the tree has none);</li>
 * <li>the no-true-child strategy, only when it is one of the two supported
 * values.</li>
 * </ul>
 * Finally it writes the node hierarchy starting at the root node. A single
 * {@link DerivedFieldMapper}, built from the document, is stored in
 * {@code m_nameMapper} and also used for writing the tree nodes.
 */
@Override
public SchemaType exportTo(final PMMLDocument pmmlDoc, final PMMLPortObjectSpec spec) {
    m_nameMapper = new DerivedFieldMapper(pmmlDoc);
    PMML pmml = pmmlDoc.getPMML();
    TreeModelDocument.TreeModel treeModel = pmml.addNewTreeModel();
    PMMLMiningSchemaTranslator.writeMiningSchema(spec, treeModel);
    treeModel.setModelName("DecisionTree");
    if (m_isClassification) {
        treeModel.setFunctionName(MININGFUNCTION.CLASSIFICATION);
    } else {
        treeModel.setFunctionName(MININGFUNCTION.REGRESSION);
    }
    // set up splitCharacteristic
    if (treeIsMultisplit(m_tree.getRootNode())) {
        treeModel.setSplitCharacteristic(SplitCharacteristic.MULTI_SPLIT);
    } else {
        treeModel.setSplitCharacteristic(SplitCharacteristic.BINARY_SPLIT);
    }
    // ----------------------------------------------
    // set up missing value strategy
    PMMLMissingValueStrategy mvStrategy = m_tree.getMVStrategy() != null ? m_tree.getMVStrategy() : PMMLMissingValueStrategy.NONE;
    treeModel.setMissingValueStrategy(MV_STRATEGY_TO_PMML_MAP.get(mvStrategy));
    // -------------------------------------------------
    // set up no true child strategy (left unset for any other value)
    PMMLNoTrueChildStrategy ntcStrategy = m_tree.getNTCStrategy();
    if (PMMLNoTrueChildStrategy.RETURN_LAST_PREDICTION.equals(ntcStrategy)) {
        treeModel.setNoTrueChildStrategy(NOTRUECHILDSTRATEGY.RETURN_LAST_PREDICTION);
    } else if (PMMLNoTrueChildStrategy.RETURN_NULL_PREDICTION.equals(ntcStrategy)) {
        treeModel.setNoTrueChildStrategy(NOTRUECHILDSTRATEGY.RETURN_NULL_PREDICTION);
    }
    // --------------------------------------------------
    // set up tree node; reuse the mapper created above instead of scanning
    // the document a second time
    NodeDocument.Node rootNode = treeModel.addNewNode();
    addTreeNode(rootNode, m_tree.getRootNode(), m_nameMapper);
    return TreeModel.type;
}
Also used : DerivedFieldMapper(org.knime.core.node.port.pmml.preproc.DerivedFieldMapper) Node(org.dmg.pmml.NodeDocument.Node) PMML(org.dmg.pmml.PMMLDocument.PMML) DecisionTreeNodeSplitPMML(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML) TreeModel(org.dmg.pmml.TreeModelDocument.TreeModel) NodeDocument(org.dmg.pmml.NodeDocument) TreeModelDocument(org.dmg.pmml.TreeModelDocument)

Aggregations

DerivedFieldMapper (org.knime.core.node.port.pmml.preproc.DerivedFieldMapper)37 PMMLPortObject (org.knime.core.node.port.pmml.PMMLPortObject)11 PMMLPortObjectSpecCreator (org.knime.core.node.port.pmml.PMMLPortObjectSpecCreator)11 PMML (org.dmg.pmml.PMMLDocument.PMML)9 DataTableSpec (org.knime.core.data.DataTableSpec)8 DerivedField (org.dmg.pmml.DerivedFieldDocument.DerivedField)7 BufferedDataTable (org.knime.core.node.BufferedDataTable)7 PortObject (org.knime.core.node.port.PortObject)7 ArrayList (java.util.ArrayList)4 NeuralNetwork (org.dmg.pmml.NeuralNetworkDocument.NeuralNetwork)4 ColumnRearranger (org.knime.core.data.container.ColumnRearranger)4 MININGFUNCTION (org.dmg.pmml.MININGFUNCTION)3 DataColumnSpec (org.knime.core.data.DataColumnSpec)3 DataType (org.knime.core.data.DataType)3 BigInteger (java.math.BigInteger)2 HashMap (java.util.HashMap)2 SchemaType (org.apache.xmlbeans.SchemaType)2 ACTIVATIONFUNCTION (org.dmg.pmml.ACTIVATIONFUNCTION)2 ArrayType (org.dmg.pmml.ArrayType)2 ClusterDocument (org.dmg.pmml.ClusterDocument)2