Use of org.dmg.pmml.mining.MiningModel in project jpmml-r by jpmml.
In class GBMConverter, method encodeModel:
/**
 * Converts the R {@code gbm} object returned by {@link #getObject()} into a PMML {@link MiningModel}.
 *
 * <p>Every element of the {@code trees} component is encoded as a regression {@link TreeModel}
 * (sharing the {@code c.splits} component) against a segment schema whose label is an unnamed
 * continuous double. The resulting tree list, the distribution name and the scalar {@code initF}
 * value are then handed to {@code encodeMiningModel} together with {@code schema}.
 *
 * @param schema the label and features of the model being converted
 * @return the encoded mining model
 */
@Override
public MiningModel encodeModel(Schema schema) {
    RGenericVector model = getObject();

    RDoubleVector initF = (RDoubleVector) model.getValue("initF");
    RGenericVector trees = (RGenericVector) model.getValue("trees");
    RGenericVector categoricalSplits = (RGenericVector) model.getValue("c.splits");
    RStringVector distributionName = (RStringVector) ((RGenericVector) model.getValue("distribution")).getValue("name");

    // Each tree is encoded with MiningFunction.REGRESSION, hence the unnamed continuous double label.
    ContinuousLabel segmentLabel = new ContinuousLabel(null, DataType.DOUBLE);
    Schema segmentSchema = new Schema(segmentLabel, schema.getFeatures());

    int treeCount = trees.size();
    List<TreeModel> treeModels = new ArrayList<>(treeCount);
    for (int treeIndex = 0; treeIndex < treeCount; treeIndex++) {
        RGenericVector tree = (RGenericVector) trees.getValue(treeIndex);
        treeModels.add(encodeTreeModel(MiningFunction.REGRESSION, tree, categoricalSplits, segmentSchema));
    }

    return encodeMiningModel(distributionName, treeModels, initF.asScalar(), schema);
}
Use of org.dmg.pmml.mining.MiningModel in project jpmml-r by jpmml.
In class RangerConverter, method encodeClassification:
/**
 * Encodes a ranger classification forest as a majority-vote PMML {@link MiningModel}.
 *
 * <p>The {@code forest} component of {@code ranger} is converted tree by tree via
 * {@code encodeForest}. The score of each terminal node is taken from the forest's
 * {@code levels} vector: the node's split value is interpreted as a 1-based index into
 * that vector (hence the {@code index - 1}).
 *
 * @param ranger the R {@code ranger} object
 * @param schema the label and features of the model being converted
 * @return a classification mining model whose segments are combined by majority vote
 * @throws IllegalArgumentException if a terminal node carries class counts, which this
 *     encoder does not support
 */
private MiningModel encodeClassification(RGenericVector ranger, Schema schema) {
    RGenericVector forest = (RGenericVector) ranger.getValue("forest");

    // Must be final: it is captured by the anonymous ScoreEncoder below.
    final RStringVector levels = (RStringVector) forest.getValue("levels");

    ScoreEncoder scoreEncoder = new ScoreEncoder() {

        @Override
        public void encode(Node node, Number splitValue, RNumberVector<?> terminalClassCount) {
            int index = ValueUtil.asInt(splitValue);

            if (terminalClassCount != null) {
                throw new IllegalArgumentException("Terminal class counts are not supported when encoding a classification forest (split value " + splitValue + ")");
            }

            // Split values are 1-based level indices.
            node.setScore(levels.getValue(index - 1));
        }
    };

    List<TreeModel> treeModels = encodeForest(forest, MiningFunction.CLASSIFICATION, scoreEncoder, schema);

    MiningModel miningModel = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel()))
        .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.MAJORITY_VOTE, treeModels));

    return miningModel;
}
Use of org.dmg.pmml.mining.MiningModel in project jpmml-sparkml by jpmml.
In class GBTRegressionModelConverter, method encodeModel:
/**
 * Encodes the Spark ML {@code GBTRegressionModel} returned by {@link #getTransformer()} as a
 * PMML regression {@link MiningModel}.
 *
 * <p>The ensemble's decision trees are encoded via
 * {@code TreeModelUtil.encodeDecisionTreeEnsemble}. They are combined with the
 * {@code WEIGHTED_SUM} method, each tree weighted by the corresponding entry of
 * {@code treeWeights()}.
 *
 * @param schema the label and features of the model being converted
 * @return the weighted-sum regression mining model
 */
@Override
public MiningModel encodeModel(Schema schema) {
    GBTRegressionModel gbtModel = getTransformer();

    List<TreeModel> members = TreeModelUtil.encodeDecisionTreeEnsemble(gbtModel, schema);

    MiningModel result = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()));

    List<Double> weights = Doubles.asList(gbtModel.treeWeights());
    Segmentation segmentation = MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.WEIGHTED_SUM, members, weights);

    return result.setSegmentation(segmentation);
}
Use of org.dmg.pmml.mining.MiningModel in project jpmml-r by jpmml.
In class GBMConverter, method encodeBinaryClassification:
/**
 * Wraps a binary gbm ensemble in a logistic (LOGIT-normalized) classification model.
 *
 * <p>The trees are first assembled by {@code createMiningModel} against a segment schema with an
 * unnamed continuous double label. The raw ensemble value is exposed as the continuous output
 * field {@code gbmValue}. That model is then passed to
 * {@code MiningModelUtil.createBinaryLogisticClassification} with the negated
 * {@code coefficient} and an intercept of zero.
 * NOTE(review): the sign flip presumably matches gbm's link/category convention — confirm.
 *
 * @param treeModels the encoded gbm trees
 * @param initF value passed unchanged to {@code createMiningModel}
 * @param coefficient coefficient applied (negated) to the {@code gbmValue} output
 * @param schema the label and features of the model being converted
 * @return the binary classification mining model
 */
private MiningModel encodeBinaryClassification(List<TreeModel> treeModels, Double initF, double coefficient, Schema schema) {
    ContinuousLabel segmentLabel = new ContinuousLabel(null, DataType.DOUBLE);
    Schema segmentSchema = new Schema(segmentLabel, schema.getFeatures());

    MiningModel valueModel = createMiningModel(treeModels, initF, segmentSchema)
        .setOutput(ModelUtil.createPredictedOutput(FieldName.create("gbmValue"), OpType.CONTINUOUS, DataType.DOUBLE));

    double logisticCoefficient = -coefficient;

    return MiningModelUtil.createBinaryLogisticClassification(valueModel, logisticCoefficient, 0d, RegressionModel.NormalizationMethod.LOGIT, true, schema);
}
Use of org.dmg.pmml.mining.MiningModel in project jpmml-r by jpmml.
In class GBMConverter, method encodeMultinomialClassification:
/**
 * Encodes a multinomial gbm ensemble as a SOFTMAX-normalized classification over one mining
 * model per target category.
 *
 * <p>{@code treeModels} is viewed as a {@code rows x columns} matrix, with one column per category
 * of the categorical label. The exact layout is defined by {@code CMatrixUtil.getColumn};
 * NOTE(review): presumably column-major — confirm. Column {@code i} becomes a mining model whose
 * raw value is exposed as the continuous output field {@code gbmValue(<category i>)}. The
 * per-category models are then combined by {@code MiningModelUtil.createClassification}.
 *
 * @param treeModels all encoded trees; their count must be a multiple of the number of categories
 * @param initF value passed unchanged to {@code createMiningModel} for every category
 * @param schema schema whose label is a {@link CategoricalLabel}
 * @return the multinomial classification mining model
 * @throws IllegalArgumentException if the label has no categories, or if the tree count is not a
 *     multiple of the number of categories (previously the surplus trees were silently dropped)
 */
private MiningModel encodeMultinomialClassification(List<TreeModel> treeModels, Double initF, Schema schema) {
    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

    Schema segmentSchema = new Schema(new ContinuousLabel(null, DataType.DOUBLE), schema.getFeatures());

    int columns = categoricalLabel.size();
    int treeCount = treeModels.size();

    // Validate the matrix shape up front: integer division below would otherwise truncate
    // (dropping trailing trees) or fail with an opaque ArithmeticException for zero categories.
    if (columns < 1 || treeCount % columns != 0) {
        throw new IllegalArgumentException("Expected the number of trees (" + treeCount + ") to be a multiple of the number of target categories (" + columns + ")");
    }

    int rows = treeCount / columns;

    List<Model> miningModels = new ArrayList<>(columns);

    for (int i = 0; i < columns; i++) {
        MiningModel miningModel = createMiningModel(CMatrixUtil.getColumn(treeModels, rows, columns, i), initF, segmentSchema)
            .setOutput(ModelUtil.createPredictedOutput(FieldName.create("gbmValue(" + categoricalLabel.getValue(i) + ")"), OpType.CONTINUOUS, DataType.DOUBLE));

        miningModels.add(miningModel);
    }

    return MiningModelUtil.createClassification(miningModels, RegressionModel.NormalizationMethod.SOFTMAX, true, schema);
}
Aggregations