Example 1 with Learner

Use of org.jpmml.xgboost.Learner in project jpmml-r by jpmml.

From the class XGBoostConverter, method encodeModel.

@Override
public MiningModel encodeModel(Schema schema) {
    RGenericVector booster = getObject();
    RNumberVector<?> ntreeLimit = (RNumberVector<?>) booster.getValue("ntreelimit", true);
    RBooleanVector compact = (RBooleanVector) booster.getValue("compact", true);
    Learner learner = ensureLearner();
    Schema xgbSchema = XGBoostUtil.toXGBoostSchema(schema);
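    // Both R elements are optional: a missing tree limit stays null, a missing compact flag defaults to false
    Integer treeLimit = (ntreeLimit != null ? ValueUtil.asInteger(ntreeLimit.asScalar()) : null);
    boolean compactFlag = (compact != null ? compact.asScalar() : false);
    MiningModel miningModel = learner.encodeMiningModel(treeLimit, compactFlag, xgbSchema);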
    return miningModel;
}
Also used: MiningModel (org.dmg.pmml.mining.MiningModel), Schema (org.jpmml.converter.Schema), Learner (org.jpmml.xgboost.Learner)
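
The optional-argument handling in encodeModel can be sketched without the JPMML classes. The following is a minimal, self-contained illustration, not the library's code: the class OptionalArgs and its methods resolveTreeLimit and resolveCompact are hypothetical names, and the strict whole-number check is an assumption rather than a documented behaviour of ValueUtil.asInteger.

```java
// Hypothetical stand-in for the null handling in encodeModel; not part of jpmml-r.
final class OptionalArgs {

    private OptionalArgs() {
    }

    // Keeps null when the tree limit is absent, otherwise converts it to an Integer.
    // Assumption: a value with a fractional part is rejected.
    public static Integer resolveTreeLimit(Number ntreeLimit) {
        if (ntreeLimit == null) {
            return null;
        }
        double value = ntreeLimit.doubleValue();
        int intValue = (int) value;
        if (intValue != value) {
            throw new IllegalArgumentException("Not a whole number: " + ntreeLimit);
        }
        return intValue;
    }

    // An absent compact flag is treated as false.
    public static boolean resolveCompact(Boolean compact) {
        return (compact != null) ? compact : false;
    }

    public static void main(String[] args) {
        System.out.println(resolveTreeLimit(null));
        System.out.println(resolveTreeLimit(10.0));
        System.out.println(resolveCompact(null));
        System.out.println(resolveCompact(Boolean.TRUE));
    }
}
```

Keeping the tree limit nullable, instead of substituting a default number, leaves the decision about "no limit" to the callee, which is how the original method passes it on.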

Example 2 with Learner

Use of org.jpmml.xgboost.Learner in project jpmml-r by jpmml.

From the class XGBoostConverter, method encodeSchema.

@Override
public void encodeSchema(RExpEncoder encoder) {
    RGenericVector booster = getObject();
    RGenericVector schema = (RGenericVector) booster.getValue("schema", true);
    RVector<?> fmap;
    try {
        fmap = (RVector<?>) booster.getValue("fmap");
    } catch (IllegalArgumentException iae) {
        // Keep the original exception as the cause so the stack trace is not lost
        throw new IllegalArgumentException("No feature map information. Please initialize the 'fmap' element", iae);
    }
    FeatureMap featureMap;
    try {
        featureMap = loadFeatureMap(fmap);
    } catch (IOException ioe) {
        throw new IllegalArgumentException(ioe);
    }
    if (schema != null) {
        RVector<?> missing = (RVector<?>) schema.getValue("missing", true);
        if (missing != null) {
            featureMap.addMissingValue(ValueUtil.formatValue(missing.asScalar()));
        }
    }
    Learner learner = ensureLearner();
    // Dependent variable
    {
        ObjFunction obj = learner.getObj();
        FieldName targetField = FieldName.create("_target");
        List<String> targetCategories = null;
        if (schema != null) {
            RStringVector responseName = (RStringVector) schema.getValue("response_name", true);
            RStringVector responseLevels = (RStringVector) schema.getValue("response_levels", true);
            if (responseName != null) {
                targetField = FieldName.create(responseName.asScalar());
            }
            if (responseLevels != null) {
                targetCategories = responseLevels.getValues();
            }
        }
        Label label = obj.encodeLabel(targetField, targetCategories, encoder);
        encoder.setLabel(label);
    }
    // Independent variables
    {
        List<Feature> features = featureMap.encodeFeatures(encoder);
        for (Feature feature : features) {
            encoder.addFeature(feature);
        }
    }
}
Also used: Label (org.jpmml.converter.Label), IOException (java.io.IOException), Feature (org.jpmml.converter.Feature), Learner (org.jpmml.xgboost.Learner), FeatureMap (org.jpmml.xgboost.FeatureMap), List (java.util.List), FieldName (org.dmg.pmml.FieldName), ObjFunction (org.jpmml.xgboost.ObjFunction)
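
The "Dependent variable" block of encodeSchema falls back to the name "_target" and to no categories when the optional schema element, or one of its entries, is missing. A minimal sketch of that fallback logic follows. It uses a plain Map in place of the R list, and the names TargetDefaults, Target and resolve are hypothetical. Rejecting a response_name that is not single-valued is an assumption about how asScalar behaves.

```java
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for the target defaulting in encodeSchema; not part of jpmml-r.
final class TargetDefaults {

    // Simple holder for the resolved target name and its optional categories.
    static final class Target {
        final String name;
        final List<String> categories; // null when no response levels were given

        Target(String name, List<String> categories) {
            this.name = name;
            this.categories = categories;
        }
    }

    private TargetDefaults() {
    }

    // schema may be null; its entries "response_name" and "response_levels" are optional.
    public static Target resolve(Map<String, List<String>> schema) {
        String name = "_target";
        List<String> categories = null;
        if (schema != null) {
            List<String> responseName = schema.get("response_name");
            List<String> responseLevels = schema.get("response_levels");
            if (responseName != null) {
                // Assumption: the name must be a single value
                if (responseName.size() != 1) {
                    throw new IllegalArgumentException("Expected a single response name");
                }
                name = responseName.get(0);
            }
            if (responseLevels != null) {
                categories = responseLevels;
            }
        }
        return new Target(name, categories);
    }
}
```

Resolving the name and the categories independently means a schema may supply either one without the other, as in the original method.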

Aggregations

Learner (org.jpmml.xgboost.Learner): 2
IOException (java.io.IOException): 1
List (java.util.List): 1
FieldName (org.dmg.pmml.FieldName): 1
MiningModel (org.dmg.pmml.mining.MiningModel): 1
Feature (org.jpmml.converter.Feature): 1
Label (org.jpmml.converter.Label): 1
Schema (org.jpmml.converter.Schema): 1
FeatureMap (org.jpmml.xgboost.FeatureMap): 1
ObjFunction (org.jpmml.xgboost.ObjFunction): 1