Search in sources:

Example 26 with TreeModel

use of org.dmg.pmml.tree.TreeModel in project jpmml-sparkml by jpmml.

From the class RandomForestRegressionModelConverter, method encodeModel.

/**
 * Converts the Spark random forest regressor into a PMML {@link MiningModel}.
 *
 * <p>Every tree of the forest becomes one segment, and the segment predictions are
 * combined with {@code Segmentation.MultipleModelMethod.AVERAGE}.
 *
 * @param schema the schema whose label drives the mining schema and whose features the trees use
 * @return a regression mining model wrapping the encoded trees
 */
@Override
public MiningModel encodeModel(Schema schema) {
    RandomForestRegressionModel forest = getTransformer();

    List<TreeModel> segments = TreeModelUtil.encodeDecisionTreeEnsemble(forest, schema);

    MiningModel result = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()));
    result.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, segments));

    return result;
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) RandomForestRegressionModel(org.apache.spark.ml.regression.RandomForestRegressionModel) MiningModel(org.dmg.pmml.mining.MiningModel)

Example 27 with TreeModel

use of org.dmg.pmml.tree.TreeModel in project jpmml-sparkml by jpmml.

From the class TreeModelUtil, method encodeTreeModel.

/**
 * Encodes a single Spark decision tree, given by its root node, as a PMML {@link TreeModel}.
 *
 * <p>The encoded root node receives an always-true predicate, and the model is marked as
 * using binary splits. If {@code TreeModelOptions.COMPACT} parses as boolean {@code true},
 * the finished model is additionally rewritten by {@link TreeModelCompactor}.
 *
 * @param node the Spark root node to encode
 * @param predicateManager shared manager used while encoding split predicates
 * @param miningFunction the PMML mining function (e.g. regression or classification)
 * @param schema the schema providing the label and features
 * @return the encoded tree model
 */
public static TreeModel encodeTreeModel(org.apache.spark.ml.tree.Node node, PredicateManager predicateManager, MiningFunction miningFunction, Schema schema) {
    // The root starts with no previously seen categorical value sets.
    Node rootNode = encodeNode(node, predicateManager, Collections.<FieldName, Set<String>>emptyMap(), miningFunction, schema)
        .setPredicate(new True());

    TreeModel result = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
        .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

    // NOTE(review): this parses the TreeModelOptions.COMPACT constant itself rather than a
    // configured option value; confirm that is intended. parseBoolean is null-safe.
    if (Boolean.parseBoolean(TreeModelOptions.COMPACT)) {
        Visitor compactor = new TreeModelCompactor();
        compactor.applyTo(result);
    }

    return result;
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) DecisionTreeModel(org.apache.spark.ml.tree.DecisionTreeModel) HashSet(java.util.HashSet) Set(java.util.Set) Visitor(org.dmg.pmml.Visitor) InternalNode(org.apache.spark.ml.tree.InternalNode) Node(org.dmg.pmml.tree.Node) LeafNode(org.apache.spark.ml.tree.LeafNode) True(org.dmg.pmml.True) TreeModelCompactor(org.jpmml.sparkml.visitors.TreeModelCompactor) FieldName(org.dmg.pmml.FieldName)

Example 28 with TreeModel

use of org.dmg.pmml.tree.TreeModel in project jpmml-sparkml by jpmml.

From the class TreeModelUtil, method encodeDecisionTreeEnsemble.

/**
 * Encodes every member tree of a Spark tree ensemble as a PMML {@link TreeModel}.
 *
 * <p>All trees are encoded against {@code schema.toAnonymousSchema()}, sharing the given
 * predicate manager. The returned list preserves the order of {@code model.trees()}.
 *
 * @param model the Spark ensemble whose trees are encoded
 * @param predicateManager shared manager used while encoding split predicates
 * @param schema the ensemble-level schema; its anonymous form is used for each tree
 * @return one encoded tree model per ensemble member, in ensemble order
 */
public static <M extends Model<M> & TreeEnsembleModel<T>, T extends Model<T> & DecisionTreeModel> List<TreeModel> encodeDecisionTreeEnsemble(M model, PredicateManager predicateManager, Schema schema) {
    Schema anonymousSchema = schema.toAnonymousSchema();

    T[] members = model.trees();
    List<TreeModel> encoded = new ArrayList<>(members.length);

    for (int i = 0; i < members.length; i++) {
        encoded.add(encodeDecisionTree(members[i], predicateManager, anonymousSchema));
    }

    return encoded;
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) DecisionTreeModel(org.apache.spark.ml.tree.DecisionTreeModel) Schema(org.jpmml.converter.Schema) ArrayList(java.util.ArrayList)

Example 29 with TreeModel

use of org.dmg.pmml.tree.TreeModel in project jpmml-sparkml by jpmml.

From the class GBTClassificationModelConverter, method encodeModel.

/**
 * Converts a Spark gradient-boosted-trees classifier into a PMML {@link MiningModel}.
 *
 * <p>Only the {@code "logistic"} loss is accepted. The trees are encoded against a
 * continuous double label and combined by a weighted sum using the model's tree weights.
 * That sum is exposed as the continuous {@code "gbtValue"} output, and the resulting model
 * is wrapped by {@code MiningModelUtil.createBinaryLogisticClassification} with LOGIT
 * normalization.
 *
 * @param schema the classification schema of the model
 * @return the binary classification mining model
 * @throws IllegalArgumentException if the model's loss type is not {@code "logistic"}
 */
@Override
public MiningModel encodeModel(Schema schema) {
    GBTClassificationModel model = getTransformer();

    String lossType = model.getLossType();
    // Guard clause: only the logistic loss can be expressed by the wrapper used below.
    if (!lossType.equals("logistic")) {
        throw new IllegalArgumentException("Loss function " + lossType + " is not supported");
    }

    // Segments predict an unnamed continuous value rather than the class label.
    Schema segmentSchema = new Schema(new ContinuousLabel(null, DataType.DOUBLE), schema.getFeatures());

    List<TreeModel> treeModels = TreeModelUtil.encodeDecisionTreeEnsemble(model, segmentSchema);

    MiningModel boostedSum = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(segmentSchema.getLabel()));
    boostedSum.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.WEIGHTED_SUM, treeModels, Doubles.asList(model.treeWeights())));
    boostedSum.setOutput(ModelUtil.createPredictedOutput(FieldName.create("gbtValue"), OpType.CONTINUOUS, DataType.DOUBLE));

    return MiningModelUtil.createBinaryLogisticClassification(boostedSum, 2d, 0d, RegressionModel.NormalizationMethod.LOGIT, false, schema);
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) MiningModel(org.dmg.pmml.mining.MiningModel) GBTClassificationModel(org.apache.spark.ml.classification.GBTClassificationModel) Schema(org.jpmml.converter.Schema) ContinuousLabel(org.jpmml.converter.ContinuousLabel)

Example 30 with TreeModel

use of org.dmg.pmml.tree.TreeModel in project drools by kiegroup.

From the class TreeModelImplementationProviderTest, method getKiePMMLModelWithSources.

/**
 * Verifies that the tree model provider generates a non-empty source map. The sources must
 * compile in memory, and every resulting class must implement {@link Serializable}.
 */
@Test
public void getKiePMMLModelWithSources() {
    TreeModel treeModel = (TreeModel) pmml.getModels().get(0);
    final CommonCompilationDTO<TreeModel> compilationDTO = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, treeModel, new HasClassLoaderMock());
    final KiePMMLModelWithSources retrieved = PROVIDER.getKiePMMLModelWithSources(compilationDTO);
    assertNotNull(retrieved);
    final Map<String, String> sourcesMap = retrieved.getSourcesMap();
    assertNotNull(sourcesMap);
    assertFalse(sourcesMap.isEmpty());
    ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
    final Map<String, Class<?>> compiled;
    try {
        compiled = KieMemoryCompiler.compile(sourcesMap, classLoader);
    } catch (Throwable t) {
        // Keep the original failure as the cause so the stack trace is not lost.
        throw new AssertionError("Failed to compile generated sources: " + t.getMessage(), t);
    }
    for (Class<?> clazz : compiled.values()) {
        // `clazz instanceof Serializable` is always true, because java.lang.Class itself
        // implements Serializable. Check the compiled type instead.
        assertTrue(Serializable.class.isAssignableFrom(clazz));
    }
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) KiePMMLTreeModel(org.kie.pmml.models.tree.model.KiePMMLTreeModel) KiePMMLModelWithSources(org.kie.pmml.commons.model.KiePMMLModelWithSources) Serializable(java.io.Serializable) BeforeClass(org.junit.BeforeClass) HasClassLoaderMock(org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock) Test(org.junit.Test)

Aggregations

TreeModel (org.dmg.pmml.tree.TreeModel)48 MiningModel (org.dmg.pmml.mining.MiningModel)17 Node (org.dmg.pmml.tree.Node)12 Test (org.junit.Test)12 ArrayList (java.util.ArrayList)11 BranchNode (org.dmg.pmml.tree.BranchNode)9 LeafNode (org.dmg.pmml.tree.LeafNode)9 Schema (org.jpmml.converter.Schema)9 ClassifierNode (org.dmg.pmml.tree.ClassifierNode)8 CategoricalLabel (org.jpmml.converter.CategoricalLabel)8 KiePMMLTreeModel (org.kie.pmml.models.drools.tree.model.KiePMMLTreeModel)8 KnowledgeBuilderImpl (org.drools.compiler.builder.impl.KnowledgeBuilderImpl)6 HasClassLoaderMock (org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock)6 PMML (org.dmg.pmml.PMML)5 HasKnowledgeBuilderMock (org.kie.pmml.models.drools.commons.implementations.HasKnowledgeBuilderMock)5 KiePMMLTreeModel (org.kie.pmml.models.tree.model.KiePMMLTreeModel)5 ConstructorDeclaration (com.github.javaparser.ast.body.ConstructorDeclaration)4 Expression (com.github.javaparser.ast.expr.Expression)4 NameExpr (com.github.javaparser.ast.expr.NameExpr)4 StringLiteralExpr (com.github.javaparser.ast.expr.StringLiteralExpr)4