Use of org.dmg.pmml.tree.TreeModel in project drools by kiegroup.
Class KiePMMLTreeModelFactoryTest, method getKiePMMLTreeModel.
/**
 * Checks that {@code KiePMMLTreeModelFactory.getKiePMMLTreeModel} yields a non-null model
 * whose name equals the source {@code treeModel}'s model name and whose target field
 * equals {@code TARGET_FIELD}.
 */
@Test
public void getKiePMMLTreeModel() throws InstantiationException, IllegalAccessException {
    // Field-name -> type mapping derived from the PMML dictionaries and the model's local transformations.
    final Map<String, KiePMMLOriginalTypeGeneratedType> typesByField = getFieldTypeMap(
            pmml.getDataDictionary(),
            pmml.getTransformationDictionary(),
            treeModel.getLocalTransformations());

    final HasKnowledgeBuilderMock builderHolder = new HasKnowledgeBuilderMock(new KnowledgeBuilderImpl());
    final CommonCompilationDTO<TreeModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, treeModel, builderHolder);
    final DroolsCompilationDTO<TreeModel> droolsDTO =
            DroolsCompilationDTO.fromCompilationDTO(commonDTO, typesByField);

    final KiePMMLTreeModel result = KiePMMLTreeModelFactory.getKiePMMLTreeModel(droolsDTO);

    assertNotNull(result);
    assertEquals(treeModel.getModelName(), result.getName());
    assertEquals(TARGET_FIELD, result.getTargetField());
}
Use of org.dmg.pmml.tree.TreeModel in project drools by kiegroup.
Class KiePMMLTreeModelFactoryTest, method getKiePMMLScorecardModelSourcesMap.
/**
 * Checks that {@code KiePMMLTreeModelFactory.getKiePMMLTreeModelSourcesMap} returns a
 * non-null map containing exactly one entry.
 *
 * <p>NOTE(review): the method name mentions "Scorecard" although the code under test is the
 * tree-model factory; the name is kept as-is here.
 */
@Test
public void getKiePMMLScorecardModelSourcesMap() {
    // Field-name -> type mapping derived from the PMML dictionaries and the model's local transformations.
    final Map<String, KiePMMLOriginalTypeGeneratedType> typesByField = getFieldTypeMap(
            pmml.getDataDictionary(),
            pmml.getTransformationDictionary(),
            treeModel.getLocalTransformations());

    final HasKnowledgeBuilderMock builderHolder = new HasKnowledgeBuilderMock(new KnowledgeBuilderImpl());
    final CommonCompilationDTO<TreeModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, treeModel, builderHolder);
    final DroolsCompilationDTO<TreeModel> droolsDTO =
            DroolsCompilationDTO.fromCompilationDTO(commonDTO, typesByField);

    final Map<String, String> sources = KiePMMLTreeModelFactory.getKiePMMLTreeModelSourcesMap(droolsDTO);

    assertNotNull(sources);
    assertEquals(1, sources.size());
}
Use of org.dmg.pmml.tree.TreeModel in project jpmml-r by jpmml.
Class RandomForestConverter, method encodeTreeModel.
/**
 * Builds a single {@link TreeModel} from the per-tree vectors of a random forest.
 *
 * <p>The node hierarchy is produced by {@code encodeNode}, starting from position {@code 0}
 * with an always-true predicate (presumably the root node index; confirm against
 * {@code encodeNode}). The resulting model uses the {@code NULL_PREDICTION} missing value
 * strategy and is flagged as a binary-split tree. When {@code this.compact} is set, a
 * {@code RandomForestCompactor} visitor is applied to the model before it is returned.
 *
 * @param miningFunction the mining function of the tree model
 * @param scoreEncoder   converts raw node predictions into node scores
 * @param leftDaughter   left child references for this tree
 * @param rightDaughter  right child references for this tree
 * @param nodepred       node predictions for this tree
 * @param bestvar        split variable references for this tree
 * @param xbestsplit     split values for this tree
 * @param schema         schema used for encoding the nodes and the mining schema
 * @return the encoded tree model
 */
private <P extends Number> TreeModel encodeTreeModel(MiningFunction miningFunction, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<P> nodepred, List<? extends Number> bestvar, List<Double> xbestsplit, Schema schema) {
    // The previously declared local holding getObject() was never read and has been removed.
    Node root = encodeNode(True.INSTANCE, 0, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, new CategoryManager(), schema);

    TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
            .setMissingValueStrategy(TreeModel.MissingValueStrategy.NULL_PREDICTION)
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

    if (this.compact) {
        Visitor visitor = new RandomForestCompactor();
        visitor.applyTo(treeModel);
    }

    return treeModel;
}
Use of org.dmg.pmml.tree.TreeModel in project jpmml-r by jpmml.
Class RandomForestConverter, method encodeRegression.
/**
 * Encodes a regression forest as a {@link MiningModel} whose segments are the individual
 * tree models, combined with the {@code AVERAGE} multiple-model method.
 *
 * <p>The forest vectors are read as matrices of {@code nrnodes} rows by {@code ntree}
 * columns; column {@code i} of each matrix supplies the data for the {@code i}-th tree.
 *
 * @param forest the forest element holding the {@code leftDaughter}, {@code rightDaughter},
 *               {@code nodepred}, {@code bestvar}, {@code xbestsplit}, {@code nrnodes} and
 *               {@code ntree} entries
 * @param schema the schema of the model; an anonymous copy of it is used for the segments
 * @return the mining model wrapping all encoded trees
 */
private MiningModel encodeRegression(RGenericVector forest, Schema schema) {
    RNumberVector<?> leftDaughter = forest.getNumericElement("leftDaughter");
    RNumberVector<?> rightDaughter = forest.getNumericElement("rightDaughter");
    RDoubleVector nodepred = forest.getDoubleElement("nodepred");
    RNumberVector<?> bestvar = forest.getNumericElement("bestvar");
    RDoubleVector xbestsplit = forest.getDoubleElement("xbestsplit");
    RIntegerVector nrnodes = forest.getIntegerElement("nrnodes");
    RNumberVector<?> ntree = forest.getNumericElement("ntree");

    // Regression predictions are used as scores unchanged.
    ScoreEncoder<Double> scoreEncoder = new ScoreEncoder<Double>() {
        @Override
        public Double encode(Double value) {
            return value;
        }
    };

    final int rowCount = nrnodes.asScalar();
    final int columnCount = ValueUtil.asInt(ntree.asScalar());

    Schema segmentSchema = schema.toAnonymousSchema();

    List<TreeModel> treeModels = new ArrayList<>();
    for (int treeIndex = 0; treeIndex < columnCount; treeIndex++) {
        // Slice out this tree's column from every matrix, in the same order as the arguments below.
        List<? extends Number> leftColumn = FortranMatrixUtil.getColumn(leftDaughter.getValues(), rowCount, columnCount, treeIndex);
        List<? extends Number> rightColumn = FortranMatrixUtil.getColumn(rightDaughter.getValues(), rowCount, columnCount, treeIndex);
        List<Double> nodepredColumn = FortranMatrixUtil.getColumn(nodepred.getValues(), rowCount, columnCount, treeIndex);
        List<? extends Number> bestvarColumn = FortranMatrixUtil.getColumn(bestvar.getValues(), rowCount, columnCount, treeIndex);
        List<Double> xbestsplitColumn = FortranMatrixUtil.getColumn(xbestsplit.getValues(), rowCount, columnCount, treeIndex);

        treeModels.add(encodeTreeModel(MiningFunction.REGRESSION, scoreEncoder, leftColumn, rightColumn, nodepredColumn, bestvarColumn, xbestsplitColumn, segmentSchema));
    }

    MiningModel forestModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()));
    return forestModel.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, treeModels));
}
Use of org.dmg.pmml.tree.TreeModel in project jpmml-r by jpmml.
Class RandomForestConverter, method encodeClassification.
/**
 * Encodes a classification forest as a {@link MiningModel} whose segments are the individual
 * tree models, combined with the {@code MAJORITY_VOTE} multiple-model method and extended
 * with a probability output for the categorical label.
 *
 * <p>{@code nodepred}, {@code bestvar} and {@code xbestsplit} are read as matrices of
 * {@code nrnodes} rows by {@code ntree} columns. {@code treemap} is read with
 * {@code 2 * nrnodes} rows per tree; each tree's slice is then split into two columns of
 * {@code nrnodes} rows, used as the left and right daughters respectively.
 *
 * @param forest the forest element holding the {@code bestvar}, {@code treemap},
 *               {@code nodepred}, {@code xbestsplit}, {@code nrnodes} and {@code ntree} entries
 * @param schema the schema of the model; its label must be a {@link CategoricalLabel}
 * @return the mining model wrapping all encoded trees
 */
private MiningModel encodeClassification(RGenericVector forest, Schema schema) {
    RNumberVector<?> bestvar = forest.getNumericElement("bestvar");
    RNumberVector<?> treemap = forest.getNumericElement("treemap");
    RIntegerVector nodepred = forest.getIntegerElement("nodepred");
    RDoubleVector xbestsplit = forest.getDoubleElement("xbestsplit");
    RIntegerVector nrnodes = forest.getIntegerElement("nrnodes");
    // Read "ntree" as a generic numeric vector, matching encodeRegression, so that a
    // non-double tree count is not rejected; it is converted via ValueUtil.asInt below.
    RNumberVector<?> ntree = forest.getNumericElement("ntree");

    int rows = nrnodes.asScalar();
    int columns = ValueUtil.asInt(ntree.asScalar());

    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

    ScoreEncoder<Integer> scoreEncoder = new ScoreEncoder<Integer>() {
        @Override
        public Object encode(Integer value) {
            // Node predictions are shifted down by one before the label lookup
            // (presumably 1-based class indices; confirm against the source data).
            return categoricalLabel.getValue(value - 1);
        }
    };

    Schema segmentSchema = schema.toAnonymousSchema();

    List<TreeModel> treeModels = new ArrayList<>();
    for (int i = 0; i < columns; i++) {
        // Each tree's treemap slice holds both daughter columns back to back.
        List<? extends Number> daughters = FortranMatrixUtil.getColumn(treemap.getValues(), 2 * rows, columns, i);

        TreeModel treeModel = encodeTreeModel(
                MiningFunction.CLASSIFICATION,
                scoreEncoder,
                FortranMatrixUtil.getColumn(daughters, rows, 2, 0),
                FortranMatrixUtil.getColumn(daughters, rows, 2, 1),
                FortranMatrixUtil.getColumn(nodepred.getValues(), rows, columns, i),
                FortranMatrixUtil.getColumn(bestvar.getValues(), rows, columns, i),
                FortranMatrixUtil.getColumn(xbestsplit.getValues(), rows, columns, i),
                segmentSchema);
        treeModels.add(treeModel);
    }

    MiningModel miningModel = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel))
            .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.MAJORITY_VOTE, treeModels))
            .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));

    return miningModel;
}
Aggregations