Search in sources:

Example 6 with TreeModel

Use of org.dmg.pmml.TreeModelDocument.TreeModel in the knime-core project by KNIME.

Class PMMLDecisionTreeTranslator, method exportTo.

/**
 * {@inheritDoc}
 * <p>
 * Appends a new {@code TreeModel} element named "DecisionTree" to the given
 * document, writes its mining schema from {@code spec}, and configures:
 * <ul>
 * <li>the mining function (classification or regression),</li>
 * <li>the split characteristic (multi-split or binary-split),</li>
 * <li>the missing value strategy ({@code NONE} when the tree has none),</li>
 * <li>the no-true-child strategy (left unset unless the tree uses one of the
 * two recognized strategies),</li>
 * </ul>
 * before serializing the tree's nodes.
 *
 * @param pmmlDoc the PMML document the tree model is added to
 * @param spec the port object spec the mining schema is written from
 * @return {@code TreeModel.type}, the schema type of the written model
 */
@Override
public SchemaType exportTo(final PMMLDocument pmmlDoc, final PMMLPortObjectSpec spec) {
    m_nameMapper = new DerivedFieldMapper(pmmlDoc);
    PMML pmml = pmmlDoc.getPMML();
    TreeModelDocument.TreeModel treeModel = pmml.addNewTreeModel();
    PMMLMiningSchemaTranslator.writeMiningSchema(spec, treeModel);
    treeModel.setModelName("DecisionTree");
    if (m_isClassification) {
        treeModel.setFunctionName(MININGFUNCTION.CLASSIFICATION);
    } else {
        treeModel.setFunctionName(MININGFUNCTION.REGRESSION);
    }
    // ----------------------------------------------
    // set up splitCharacteristic
    if (treeIsMultisplit(m_tree.getRootNode())) {
        treeModel.setSplitCharacteristic(SplitCharacteristic.MULTI_SPLIT);
    } else {
        treeModel.setSplitCharacteristic(SplitCharacteristic.BINARY_SPLIT);
    }
    // ----------------------------------------------
    // set up missing value strategy (fall back to NONE if the tree has none)
    PMMLMissingValueStrategy mvStrategy = m_tree.getMVStrategy();
    if (mvStrategy == null) {
        mvStrategy = PMMLMissingValueStrategy.NONE;
    }
    treeModel.setMissingValueStrategy(MV_STRATEGY_TO_PMML_MAP.get(mvStrategy));
    // -------------------------------------------------
    // set up no true child strategy; any other value leaves the attribute
    // unset so the PMML default applies
    PMMLNoTrueChildStrategy ntcStrategy = m_tree.getNTCStrategy();
    if (PMMLNoTrueChildStrategy.RETURN_LAST_PREDICTION.equals(ntcStrategy)) {
        treeModel.setNoTrueChildStrategy(NOTRUECHILDSTRATEGY.RETURN_LAST_PREDICTION);
    } else if (PMMLNoTrueChildStrategy.RETURN_NULL_PREDICTION.equals(ntcStrategy)) {
        treeModel.setNoTrueChildStrategy(NOTRUECHILDSTRATEGY.RETURN_NULL_PREDICTION);
    }
    // --------------------------------------------------
    // set up tree node; reuse the mapper created above instead of building
    // a second one from the same document (no derived fields were added
    // to the document in between)
    NodeDocument.Node rootNode = treeModel.addNewNode();
    addTreeNode(rootNode, m_tree.getRootNode(), m_nameMapper);
    return TreeModel.type;
}
Also used: DerivedFieldMapper(org.knime.core.node.port.pmml.preproc.DerivedFieldMapper) Node(org.dmg.pmml.NodeDocument.Node) PMML(org.dmg.pmml.PMMLDocument.PMML) DecisionTreeNodeSplitPMML(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML) TreeModel(org.dmg.pmml.TreeModelDocument.TreeModel) NodeDocument(org.dmg.pmml.NodeDocument) TreeModelDocument(org.dmg.pmml.TreeModelDocument)

Example 7 with TreeModel

Use of org.dmg.pmml.TreeModelDocument.TreeModel in the knime-core project by KNIME.

Class PMMLUtils, method getFirstMiningSchema.

/**
 * Returns the mining schema of the first model of the requested type.
 * <p>
 * The document is first checked for the presence of a model of that type.
 * Each supported model type is then handled explicitly, because the
 * generated PMML model classes do not share a common base type exposing
 * {@code getMiningSchema()}.
 *
 * @param pmmlDoc the PMML document to extract the mining schema from
 * @param type the schema type of the model
 * @return the mining schema of the first model of the given type, or
 *         {@code null} if the document contains no such model or the type
 *         is not one of the handled model types
 */
public static MiningSchema getFirstMiningSchema(final PMMLDocument pmmlDoc, final SchemaType type) {
    final Map<PMMLModelType, Integer> modelCounts = getNumberOfModels(pmmlDoc);
    if (!modelCounts.containsKey(PMMLModelType.getType(type))) {
        // no model of the requested type in the document
        return null;
    }
    final PMML pmml = pmmlDoc.getPMML();
    if (AssociationModel.type.equals(type)) {
        return pmml.getAssociationModelArray(0).getMiningSchema();
    }
    if (ClusteringModel.type.equals(type)) {
        return pmml.getClusteringModelArray(0).getMiningSchema();
    }
    if (GeneralRegressionModel.type.equals(type)) {
        return pmml.getGeneralRegressionModelArray(0).getMiningSchema();
    }
    if (MiningModel.type.equals(type)) {
        return pmml.getMiningModelArray(0).getMiningSchema();
    }
    if (NaiveBayesModel.type.equals(type)) {
        return pmml.getNaiveBayesModelArray(0).getMiningSchema();
    }
    if (NeuralNetwork.type.equals(type)) {
        return pmml.getNeuralNetworkArray(0).getMiningSchema();
    }
    if (RegressionModel.type.equals(type)) {
        return pmml.getRegressionModelArray(0).getMiningSchema();
    }
    if (RuleSetModel.type.equals(type)) {
        return pmml.getRuleSetModelArray(0).getMiningSchema();
    }
    if (SequenceModel.type.equals(type)) {
        return pmml.getSequenceModelArray(0).getMiningSchema();
    }
    if (SupportVectorMachineModel.type.equals(type)) {
        return pmml.getSupportVectorMachineModelArray(0).getMiningSchema();
    }
    if (TextModel.type.equals(type)) {
        return pmml.getTextModelArray(0).getMiningSchema();
    }
    if (TimeSeriesModel.type.equals(type)) {
        return pmml.getTimeSeriesModelArray(0).getMiningSchema();
    }
    if (TreeModel.type.equals(type)) {
        return pmml.getTreeModelArray(0).getMiningSchema();
    }
    // type present in the document but not handled here
    return null;
}
Also used: RuleSetModel(org.dmg.pmml.RuleSetModelDocument.RuleSetModel) SequenceModel(org.dmg.pmml.SequenceModelDocument.SequenceModel) TextModel(org.dmg.pmml.TextModelDocument.TextModel) NaiveBayesModel(org.dmg.pmml.NaiveBayesModelDocument.NaiveBayesModel) TimeSeriesModel(org.dmg.pmml.TimeSeriesModelDocument.TimeSeriesModel) NeuralNetwork(org.dmg.pmml.NeuralNetworkDocument.NeuralNetwork) RegressionModel(org.dmg.pmml.RegressionModelDocument.RegressionModel) GeneralRegressionModel(org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel) TreeModel(org.dmg.pmml.TreeModelDocument.TreeModel) MiningModel(org.dmg.pmml.MiningModelDocument.MiningModel) PMML(org.dmg.pmml.PMMLDocument.PMML) SupportVectorMachineModel(org.dmg.pmml.SupportVectorMachineModelDocument.SupportVectorMachineModel) AssociationModel(org.dmg.pmml.AssociationModelDocument.AssociationModel) ClusteringModel(org.dmg.pmml.ClusteringModelDocument.ClusteringModel)

Example 8 with TreeModel

Use of org.dmg.pmml.TreeModelDocument.TreeModel in the knime-core project by KNIME.

Class PMMLDecisionTreeTranslator, method initializeFrom.

/**
 * {@inheritDoc}
 * <p>
 * Creates a fresh derived-field name mapper for the document, then parses
 * the decision tree from the first {@code TreeModel} it contains.
 *
 * @throws IllegalArgumentException if the document contains no tree model
 */
@Override
public void initializeFrom(final PMMLDocument pmmlDoc) {
    // set the mapper before parsing; the parser may rely on it
    m_nameMapper = new DerivedFieldMapper(pmmlDoc);
    final TreeModel[] treeModels = pmmlDoc.getPMML().getTreeModelArray();
    if (treeModels.length < 1) {
        throw new IllegalArgumentException("No treemodel provided.");
    }
    m_tree = parseDecTreeFromModel(treeModels[0]);
}
Also used: DerivedFieldMapper(org.knime.core.node.port.pmml.preproc.DerivedFieldMapper) TreeModel(org.dmg.pmml.TreeModelDocument.TreeModel)

Example 9 with TreeModel

Use of org.dmg.pmml.TreeModelDocument.TreeModel in the knime-core project by KNIME.

Class PMMLMiningSchemaTranslator, method writeMiningSchema.

/**
 * Writes the MiningSchema based upon the fields of the passed
 * {@link PMMLPortObjectSpec}.
 * <p>
 * Fields are written in this order:
 * <ol>
 * <li>Learning fields, as active fields.</li>
 * <li>Fields referenced in local transformations, as active fields.</li>
 * <li>Target fields, with usageType "target".</li>
 * </ol>
 * The usageType attribute is omitted for active fields because "active" is
 * the PMML default. Every mining field gets invalidValueTreatment "asIs".
 * <p>
 * Each field name is written at most once. A name that is also a target
 * field is written only as a target. Names repeated within or across the
 * field lists are skipped after their first occurrence.
 * <p>
 * The schema is attached only if the model's schema type is one of the
 * model types handled below; for any other type it is discarded.
 *
 * @param portSpec based upon this port object spec the mining schema is
 *            written
 * @param model the PMML model element to write the mining schema to
 */
public static void writeMiningSchema(final PMMLPortObjectSpec portSpec, final XmlObject model) {
    MiningSchema miningSchema = MiningSchema.Factory.newInstance();
    // avoid duplicate entries: target names are reserved up front so they
    // are excluded from the active fields and written last as predicted
    Set<String> writtenNames = new HashSet<String>(portSpec.getTargetFields());
    for (String colName : portSpec.getLearningFields()) {
        if (writtenNames.add(colName)) {
            // don't write usageType = active (is default)
            appendMiningField(miningSchema, colName);
        }
    }
    // add all fields referenced in local transformations
    for (String colName : portSpec.getPreprocessingFields()) {
        if (writtenNames.add(colName)) {
            // don't write usageType = active (is default)
            appendMiningField(miningSchema, colName);
        }
    }
    // target columns = predicted
    Set<String> writtenTargets = new HashSet<String>();
    for (String colName : portSpec.getTargetFields()) {
        if (writtenTargets.add(colName)) {
            appendMiningField(miningSchema, colName).setUsageType(FIELDUSAGETYPE.TARGET);
        }
    }
    /* Unfortunately the PMML models have no common base class. Therefore
     * a cast to the specific type is necessary for being able to add the
     * mining schema. */
    SchemaType type = model.schemaType();
    if (AssociationModel.type.equals(type)) {
        ((AssociationModel) model).setMiningSchema(miningSchema);
    } else if (ClusteringModel.type.equals(type)) {
        ((ClusteringModel) model).setMiningSchema(miningSchema);
    } else if (GeneralRegressionModel.type.equals(type)) {
        ((GeneralRegressionModel) model).setMiningSchema(miningSchema);
    } else if (MiningModel.type.equals(type)) {
        ((MiningModel) model).setMiningSchema(miningSchema);
    } else if (NaiveBayesModel.type.equals(type)) {
        ((NaiveBayesModel) model).setMiningSchema(miningSchema);
    } else if (NeuralNetwork.type.equals(type)) {
        ((NeuralNetwork) model).setMiningSchema(miningSchema);
    } else if (RegressionModel.type.equals(type)) {
        ((RegressionModel) model).setMiningSchema(miningSchema);
    } else if (RuleSetModel.type.equals(type)) {
        ((RuleSetModel) model).setMiningSchema(miningSchema);
    } else if (SequenceModel.type.equals(type)) {
        ((SequenceModel) model).setMiningSchema(miningSchema);
    } else if (SupportVectorMachineModel.type.equals(type)) {
        ((SupportVectorMachineModel) model).setMiningSchema(miningSchema);
    } else if (TextModel.type.equals(type)) {
        ((TextModel) model).setMiningSchema(miningSchema);
    } else if (TimeSeriesModel.type.equals(type)) {
        ((TimeSeriesModel) model).setMiningSchema(miningSchema);
    } else if (TreeModel.type.equals(type)) {
        ((TreeModel) model).setMiningSchema(miningSchema);
    } else if (NearestNeighborModel.type.equals(type)) {
        ((NearestNeighborModel) model).setMiningSchema(miningSchema);
    }
}

/**
 * Appends a new mining field with the given name and invalidValueTreatment
 * "asIs" to the schema.
 *
 * @param miningSchema the schema to append the field to
 * @param name the field name
 * @return the newly added mining field, for further configuration
 */
private static MiningField appendMiningField(final MiningSchema miningSchema, final String name) {
    MiningField miningField = miningSchema.addNewMiningField();
    miningField.setName(name);
    miningField.setInvalidValueTreatment(INVALIDVALUETREATMENTMETHOD.AS_IS);
    return miningField;
}
Also used: SequenceModel(org.dmg.pmml.SequenceModelDocument.SequenceModel) MiningField(org.dmg.pmml.MiningFieldDocument.MiningField) TextModel(org.dmg.pmml.TextModelDocument.TextModel) NaiveBayesModel(org.dmg.pmml.NaiveBayesModelDocument.NaiveBayesModel) SchemaType(org.apache.xmlbeans.SchemaType) RegressionModel(org.dmg.pmml.RegressionModelDocument.RegressionModel) GeneralRegressionModel(org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel) TreeModel(org.dmg.pmml.TreeModelDocument.TreeModel) MiningSchema(org.dmg.pmml.MiningSchemaDocument.MiningSchema) HashSet(java.util.HashSet) AssociationModel(org.dmg.pmml.AssociationModelDocument.AssociationModel)

Example 10 with TreeModel

Use of org.dmg.pmml.TreeModelDocument.TreeModel in the knime-core project by KNIME.

Class PMMLPortObject, method moveDerivedFields.

/**
 * Moves the content of the transformation dictionary to the local
 * transformations of the first model of the given type.
 * <p>
 * The dictionary's derived fields and extensions are copied into a new
 * {@link LocalTransformations} element. That element is then set on the
 * first model of {@code type}. If this succeeds, the dictionary's derived
 * fields and extensions are cleared.
 * <p>
 * NOTE(review): if the type is not handled, an error is logged (for a
 * non-null type) and the dictionary is left untouched. The returned element
 * still holds the copied fields in that case.
 *
 * @param type the type of model to move the derived fields to
 * @return the {@link LocalTransformations} element containing the moved
 *      derived fields or an empty local transformation object if there is
 *      no transformation dictionary
 */
private LocalTransformations moveDerivedFields(final SchemaType type) {
    final PMML pmml = m_pmmlDoc.getPMML();
    final TransformationDictionary dictionary = pmml.getTransformationDictionary();
    final LocalTransformations localTrans = LocalTransformations.Factory.newInstance();
    if (dictionary == null) {
        // nothing to be moved
        return localTrans;
    }
    // copy first, then attach: the model receives the populated element
    localTrans.setDerivedFieldArray(dictionary.getDerivedFieldArray());
    localTrans.setExtensionArray(dictionary.getExtensionArray());
    if (attachLocalTransformations(pmml, type, localTrans)) {
        // remove derived fields from TransformationDictionary
        dictionary.setDerivedFieldArray(new DerivedField[0]);
        dictionary.setExtensionArray(new ExtensionDocument.Extension[0]);
    } else if (type != null) {
        LOGGER.error("Could not move TransformationDictionary to " + "unsupported model of type \"" + type + "\".");
    }
    return localTrans;
}

/**
 * Sets the given local transformations on the first model of the given
 * type. Each model type is handled explicitly because the generated PMML
 * model classes share no common base type.
 *
 * @param pmml the PMML element holding the models
 * @param type the schema type of the target model (may be {@code null})
 * @param localTrans the local transformations to set
 * @return {@code true} if the type is handled and the transformations were
 *         set, {@code false} otherwise
 */
private static boolean attachLocalTransformations(final PMML pmml, final SchemaType type,
    final LocalTransformations localTrans) {
    if (AssociationModel.type.equals(type)) {
        pmml.getAssociationModelArray(0).setLocalTransformations(localTrans);
    } else if (ClusteringModel.type.equals(type)) {
        pmml.getClusteringModelArray(0).setLocalTransformations(localTrans);
    } else if (GeneralRegressionModel.type.equals(type)) {
        pmml.getGeneralRegressionModelArray(0).setLocalTransformations(localTrans);
    } else if (MiningModel.type.equals(type)) {
        pmml.getMiningModelArray(0).setLocalTransformations(localTrans);
    } else if (NaiveBayesModel.type.equals(type)) {
        pmml.getNaiveBayesModelArray(0).setLocalTransformations(localTrans);
    } else if (NeuralNetwork.type.equals(type)) {
        pmml.getNeuralNetworkArray(0).setLocalTransformations(localTrans);
    } else if (RegressionModel.type.equals(type)) {
        pmml.getRegressionModelArray(0).setLocalTransformations(localTrans);
    } else if (RuleSetModel.type.equals(type)) {
        pmml.getRuleSetModelArray(0).setLocalTransformations(localTrans);
    } else if (SequenceModel.type.equals(type)) {
        pmml.getSequenceModelArray(0).setLocalTransformations(localTrans);
    } else if (SupportVectorMachineModel.type.equals(type)) {
        pmml.getSupportVectorMachineModelArray(0).setLocalTransformations(localTrans);
    } else if (TextModel.type.equals(type)) {
        pmml.getTextModelArray(0).setLocalTransformations(localTrans);
    } else if (TimeSeriesModel.type.equals(type)) {
        pmml.getTimeSeriesModelArray(0).setLocalTransformations(localTrans);
    } else if (TreeModel.type.equals(type)) {
        pmml.getTreeModelArray(0).setLocalTransformations(localTrans);
    } else {
        return false;
    }
    return true;
}
Also used: RuleSetModel(org.dmg.pmml.RuleSetModelDocument.RuleSetModel) SequenceModel(org.dmg.pmml.SequenceModelDocument.SequenceModel) TransformationDictionary(org.dmg.pmml.TransformationDictionaryDocument.TransformationDictionary) TextModel(org.dmg.pmml.TextModelDocument.TextModel) ExtensionDocument(org.dmg.pmml.ExtensionDocument) NaiveBayesModel(org.dmg.pmml.NaiveBayesModelDocument.NaiveBayesModel) TimeSeriesModel(org.dmg.pmml.TimeSeriesModelDocument.TimeSeriesModel) NeuralNetwork(org.dmg.pmml.NeuralNetworkDocument.NeuralNetwork) GeneralRegressionModel(org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel) RegressionModel(org.dmg.pmml.RegressionModelDocument.RegressionModel) TreeModel(org.dmg.pmml.TreeModelDocument.TreeModel) LocalTransformations(org.dmg.pmml.LocalTransformationsDocument.LocalTransformations) MiningModel(org.dmg.pmml.MiningModelDocument.MiningModel) PMML(org.dmg.pmml.PMMLDocument.PMML) SupportVectorMachineModel(org.dmg.pmml.SupportVectorMachineModelDocument.SupportVectorMachineModel) AssociationModel(org.dmg.pmml.AssociationModelDocument.AssociationModel) ClusteringModel(org.dmg.pmml.ClusteringModelDocument.ClusteringModel)

Aggregations

TreeModel (org.dmg.pmml.TreeModelDocument.TreeModel): 11; PMML (org.dmg.pmml.PMMLDocument.PMML): 6; GeneralRegressionModel (org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel): 5; RegressionModel (org.dmg.pmml.RegressionModelDocument.RegressionModel): 5; ClusteringModel (org.dmg.pmml.ClusteringModelDocument.ClusteringModel): 4; NaiveBayesModel (org.dmg.pmml.NaiveBayesModelDocument.NaiveBayesModel): 4; NeuralNetwork (org.dmg.pmml.NeuralNetworkDocument.NeuralNetwork): 4; RuleSetModel (org.dmg.pmml.RuleSetModelDocument.RuleSetModel): 4; SupportVectorMachineModel (org.dmg.pmml.SupportVectorMachineModelDocument.SupportVectorMachineModel): 4; AssociationModel (org.dmg.pmml.AssociationModelDocument.AssociationModel): 3; SequenceModel (org.dmg.pmml.SequenceModelDocument.SequenceModel): 3; TextModel (org.dmg.pmml.TextModelDocument.TextModel): 3; DerivedFieldMapper (org.knime.core.node.port.pmml.preproc.DerivedFieldMapper): 3; SchemaType (org.apache.xmlbeans.SchemaType): 2; LocalTransformations (org.dmg.pmml.LocalTransformationsDocument.LocalTransformations): 2; MiningModel (org.dmg.pmml.MiningModelDocument.MiningModel): 2; TimeSeriesModel (org.dmg.pmml.TimeSeriesModelDocument.TimeSeriesModel): 2; TransformationDictionary (org.dmg.pmml.TransformationDictionaryDocument.TransformationDictionary): 2; TreeModelDocument (org.dmg.pmml.TreeModelDocument): 2; HashSet (java.util.HashSet): 1