Search in sources:

Example 1 with Segmentation

use of org.dmg.pmml.mining.Segmentation in project pyramid by cheng-li.

the class MiningModelUtil method createModelChain.

/**
 * Wraps {@code models} into a {@link MiningModel} whose segmentation uses {@code MODEL_CHAIN}.
 *
 * <p>The mining function and (simplified) math context of the resulting model are taken from the last
 * model in {@code models}. Its mining schema is built from the label of {@code schema}.
 *
 * @param models the models to chain, in chain order; must not be empty
 * @param schema the schema supplying the target label
 * @return a mining model whose segmentation chains {@code models}
 * @throws IllegalArgumentException if {@code models} is empty
 */
public static MiningModel createModelChain(List<? extends Model> models, Schema schema) {
    if (models.isEmpty()) {
        throw new IllegalArgumentException("A model chain requires at least one model");
    }
    Segmentation segmentation = createSegmentation(Segmentation.MultipleModelMethod.MODEL_CHAIN, models);
    // The chain's overall output is described by its final member.
    Model lastModel = Iterables.getLast(models);
    return new MiningModel(lastModel.getMiningFunction(), ModelUtil.createMiningSchema(schema.getLabel()))
        .setMathContext(ModelUtil.simplifyMathContext(lastModel.getMathContext()))
        .setSegmentation(segmentation);
}
Also used : Segmentation(org.dmg.pmml.mining.Segmentation) MiningModel(org.dmg.pmml.mining.MiningModel) Model(org.dmg.pmml.Model) RegressionModel(org.dmg.pmml.regression.RegressionModel)

Example 2 with Segmentation

use of org.dmg.pmml.mining.Segmentation in project pyramid by cheng-li.

the class MiningModelUtil method createSegmentation.

/**
 * Creates a segmentation with one always-true segment per model.
 *
 * <p>Segments receive 1-based string ids in list order. A segment weight is written only when the
 * corresponding weight is non-null and {@code ValueUtil.isOne} reports it is not one.
 *
 * @param multipleModelMethod how the segment results are combined
 * @param models the segment models, in segment order
 * @param weights optional per-model weights, parallel to {@code models}; may be {@code null}
 * @return the new segmentation
 * @throws IllegalArgumentException if {@code weights} is non-null and its size differs from {@code models}
 */
public static Segmentation createSegmentation(Segmentation.MultipleModelMethod multipleModelMethod, List<? extends Model> models, List<? extends Number> weights) {
    if (weights != null && models.size() != weights.size()) {
        throw new IllegalArgumentException("Expected " + models.size() + " weights (one per model), but got " + weights.size());
    }
    List<Segment> segments = new ArrayList<>(models.size());
    for (int i = 0; i < models.size(); i++) {
        Model model = models.get(i);
        Number weight = (weights != null ? weights.get(i) : null);
        Segment segment = new Segment().setId(String.valueOf(i + 1)).setPredicate(new True()).setModel(model);
        // Skip unit weights so the default weight is not written out explicitly.
        if (weight != null && !ValueUtil.isOne(weight)) {
            segment.setWeight(ValueUtil.asDouble(weight));
        }
        segments.add(segment);
    }
    return new Segmentation(multipleModelMethod, segments);
}
Also used : Segmentation(org.dmg.pmml.mining.Segmentation) ArrayList(java.util.ArrayList) Model(org.dmg.pmml.Model) MiningModel(org.dmg.pmml.mining.MiningModel) RegressionModel(org.dmg.pmml.regression.RegressionModel) True(org.dmg.pmml.True) Segment(org.dmg.pmml.mining.Segment)

Example 3 with Segmentation

use of org.dmg.pmml.mining.Segmentation in project shifu by ShifuML.

the class PMMLTranslator method build.

/**
 * Builds a PMML document from the given trained models.
 *
 * <p>When {@code isOutBaggingToOne} is set, every model in {@code basicMLs} becomes a segment of a single
 * {@link MiningModel} that averages the segment outputs. Otherwise only the first model in the list is
 * translated and the remaining entries are ignored.
 *
 * @param basicMLs the trained models to translate; must be non-null and non-empty
 * @return the assembled PMML document
 * @throws IllegalArgumentException if {@code basicMLs} is null or empty
 */
public PMML build(List<BasicML> basicMLs) {
    if (basicMLs == null || basicMLs.isEmpty()) {
        throw new IllegalArgumentException("Input ml model list is empty.");
    }
    PMML pmml = new PMML();
    pmml.setHeader(buildPmmlHeader());
    // create and set data dictionary for all bagging models
    pmml.setDataDictionary(this.dataDictionaryCreator.build(null));
    if (isOutBaggingToOne) {
        pmml.addModels(buildBaggingMiningModel(basicMLs));
    } else {
        // Non-bagging mode: only the first model is translated.
        BasicML basicML = basicMLs.get(0);
        Model model = buildModelSkeleton(basicML);
        this.specifCreator.build(basicML, model);
        pmml.addModels(model);
    }
    return pmml;
}

/** Creates the PMML header holding the copyright notice and the shifu application name/version. */
private Header buildPmmlHeader() {
    Header header = new Header();
    header.setCopyright(" Copyright [2013-2018] PayPal Software Foundation\n" + "\n" + " Licensed under the Apache License, Version 2.0 (the \"License\");\n" + " you may not use this file except in compliance with the License.\n" + " You may obtain a copy of the License at\n" + "\n" + "    http://www.apache.org/licenses/LICENSE-2.0\n" + "\n" + " Unless required by applicable law or agreed to in writing, software\n" + " distributed under the License is distributed on an \"AS IS\" BASIS,\n" + " WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" + " See the License for the specific language governing permissions and\n" + " limitations under the License.\n");
    Application application = new Application();
    header.setApplication(application);
    application.setName("shifu");
    application.setVersion(readApplicationVersion());
    return header;
}

/**
 * Reads the {@code version} main attribute from the manifest of the jar containing
 * {@code TreeEnsemblePMMLTranslator}.
 *
 * <p>Best effort: returns {@code null} and logs a warning when the jar, its manifest or the attribute
 * cannot be read.
 */
private String readApplicationVersion() {
    String jarPath = JarManager.findContainingJar(TreeEnsemblePMMLTranslator.class);
    if (jarPath == null) {
        LOG.warn("Cannot locate the jar containing the PMML translator; application version is not set.");
        return null;
    }
    String version = null;
    try (JarFile jar = new JarFile(jarPath)) {
        Manifest manifest = jar.getManifest();
        if (manifest == null) {
            LOG.warn("Jar " + jarPath + " has no manifest; application version is not set.");
        } else {
            version = manifest.getMainAttributes().getValue("version");
        }
    } catch (IOException | RuntimeException e) {
        // A missing version must not prevent the PMML from being built. If only close() fails,
        // the version read above is kept.
        LOG.warn("Failed to read application version from jar " + jarPath, e);
    }
    return version;
}

/** Translates one trained model and attaches its mining schema, model stats and local transformations. */
private Model buildModelSkeleton(BasicML basicML) {
    // create model element
    Model model = this.modelCreator.build(basicML);
    // create mining schema
    model.setMiningSchema(this.miningSchemaCreator.build(basicML));
    // create variable statistical info
    model.setModelStats(this.modelStatsCreator.build(basicML));
    // create variable transform
    model.setLocalTransformations(this.localTransformationsCreator.build(basicML));
    return model;
}

/**
 * Builds a regression {@link MiningModel} with one segment per entry of {@code basicMLs}, combined with
 * {@code AVERAGE}.
 */
private MiningModel buildBaggingMiningModel(List<BasicML> basicMLs) {
    MiningModel miningModel = new MiningModel();
    miningModel.setMiningSchema(this.miningSchemaCreator.build(null));
    miningModel.setMiningFunction(MiningFunction.fromValue("regression"));
    // NOTE(review): bagging mode assumes modelCreator is an NNPmmlModelCreator; any other creator fails here
    // with ClassCastException.
    miningModel.setTargets(((NNPmmlModelCreator) this.modelCreator).createTargets());
    AbstractSpecifCreator miningModelCreator = new MiningModelPmmlCreator(this.specifCreator.getModelConfig(), this.specifCreator.getColumnConfigList());
    miningModelCreator.build(null, miningModel);
    Segmentation seg = new Segmentation();
    miningModel.setSegmentation(seg);
    seg.setMultipleModelMethod(MultipleModelMethod.AVERAGE);
    List<Segment> segments = seg.getSegments();
    int idCount = 0;
    for (BasicML basicML : basicMLs) {
        Model segmentModel = buildModelSkeleton(basicML);
        this.specifCreator.build(basicML, segmentModel, idCount);
        Segment segment = new Segment();
        // The "Segement" spelling is part of the emitted segment id and is kept unchanged.
        segment.setId("Segement" + idCount);
        segment.setPredicate(new True());
        segment.setModel(segmentModel);
        segments.add(segment);
        idCount++;
    }
    return miningModel;
}
Also used : Segmentation(org.dmg.pmml.mining.Segmentation) AbstractSpecifCreator(ml.shifu.shifu.core.pmml.builder.creator.AbstractSpecifCreator) True(org.dmg.pmml.True) BasicML(org.encog.ml.BasicML) IOException(java.io.IOException) JarFile(java.util.jar.JarFile) Manifest(java.util.jar.Manifest) Segment(org.dmg.pmml.mining.Segment) MiningModelPmmlCreator(ml.shifu.shifu.core.pmml.builder.impl.MiningModelPmmlCreator) Header(org.dmg.pmml.Header) MiningModel(org.dmg.pmml.mining.MiningModel) Model(org.dmg.pmml.Model) PMML(org.dmg.pmml.PMML) Application(org.dmg.pmml.Application)

Example 4 with Segmentation

use of org.dmg.pmml.mining.Segmentation in project jpmml-sparkml by jpmml.

the class ModelConverter method getLastModel.

/**
 * Returns the model of the last segment when {@code model} is a {@link MiningModel} whose segmentation
 * uses {@code MODEL_CHAIN}; otherwise returns {@code model} unchanged.
 *
 * <p>{@code model} is also returned as-is for a chain with no segments, and for a mining model whose
 * segmentation or multiple-model method is not set.
 *
 * @param model the model to inspect
 * @return the last chained model, or {@code model} itself
 */
protected org.dmg.pmml.Model getLastModel(org.dmg.pmml.Model model) {
    if (!(model instanceof MiningModel)) {
        return model;
    }
    Segmentation segmentation = ((MiningModel) model).getSegmentation();
    // Guard against partially populated mining models, which would otherwise fail with NullPointerException.
    if (segmentation == null) {
        return model;
    }
    MultipleModelMethod multipleModelMethod = segmentation.getMultipleModelMethod();
    if (multipleModelMethod != MultipleModelMethod.MODEL_CHAIN) {
        return model;
    }
    List<Segment> segments = segmentation.getSegments();
    if (segments.isEmpty()) {
        return model;
    }
    return segments.get(segments.size() - 1).getModel();
}
Also used : MiningModel(org.dmg.pmml.mining.MiningModel) Segmentation(org.dmg.pmml.mining.Segmentation) MultipleModelMethod(org.dmg.pmml.mining.Segmentation.MultipleModelMethod) Segment(org.dmg.pmml.mining.Segment)

Example 5 with Segmentation

use of org.dmg.pmml.mining.Segmentation in project shifu by ShifuML.

the class TreeEnsemblePmmlCreator method convert.

/**
 * Converts a tree ensemble holding a single tree list into a PMML {@link MiningModel}.
 *
 * <p>Each tree becomes one {@code TreeModel} segment with an always-true predicate, and the segments are
 * combined with {@code weightedAverage}. The mining function is {@code classification} or
 * {@code regression} according to {@code treeModel.isClassification()}.
 *
 * @param treeModel the trained tree ensemble; must contain exactly one tree list
 * @return the mining model with one weighted segment per tree
 * @throws UnsupportedOperationException if {@code treeModel} does not contain exactly one tree list
 *         (e.g. a bagged model)
 */
public MiningModel convert(IndependentTreeModel treeModel) {
    // Fail fast, before any PMML elements are built: such case we only support treeModel is one element list.
    if (treeModel.getTrees().size() != 1) {
        throw new UnsupportedOperationException("Bagging model cannot be supported in PMML generation.");
    }
    MiningModel gbt = new MiningModel();
    MiningSchema miningSchema = new TreeModelMiningSchemaCreator(this.modelConfig, this.columnConfigList).build(null);
    gbt.setMiningSchema(miningSchema);
    gbt.setMiningFunction(MiningFunction.fromValue(treeModel.isClassification() ? "classification" : "regression"));
    gbt.setTargets(createTargets(this.modelConfig));
    Segmentation seg = new Segmentation();
    gbt.setSegmentation(seg);
    seg.setMultipleModelMethod(MultipleModelMethod.fromValue("weightedAverage"));
    List<Segment> segments = seg.getSegments();
    int idCount = 0;
    for (TreeNode tn : treeModel.getTrees().get(0)) {
        TreeNodePmmlElementCreator tnec = new TreeNodePmmlElementCreator(this.modelConfig, this.columnConfigList, treeModel);
        org.dmg.pmml.tree.Node root = tnec.convert(tn.getNode());
        TreeModelPmmlElementCreator tmec = new TreeModelPmmlElementCreator(this.modelConfig, this.columnConfigList);
        org.dmg.pmml.tree.TreeModel tm = tmec.convert(treeModel, root);
        tm.setModelName(String.valueOf(idCount));
        Segment segment = new Segment();
        if (treeModel.isGBDT()) {
            // NOTE(review): getTrees().size() is always 1 here because of the guard above, so this multiplier
            // has no effect. Confirm the intended GBDT weight scaling.
            segment.setWeight(treeModel.getWeights().get(0).get(idCount) * treeModel.getTrees().size());
        } else {
            segment.setWeight(treeModel.getWeights().get(0).get(idCount));
        }
        // The "Segement" spelling is part of the emitted segment id and is kept unchanged.
        segment.setId("Segement" + idCount);
        idCount++;
        segment.setPredicate(new True());
        segment.setModel(tm);
        segments.add(segment);
    }
    return gbt;
}
Also used : Segmentation(org.dmg.pmml.mining.Segmentation) True(org.dmg.pmml.True) Segment(org.dmg.pmml.mining.Segment) MiningModel(org.dmg.pmml.mining.MiningModel) MiningSchema(org.dmg.pmml.MiningSchema) TreeNode(ml.shifu.shifu.core.dtrain.dt.TreeNode)

Aggregations

MiningModel (org.dmg.pmml.mining.MiningModel)6 Segmentation (org.dmg.pmml.mining.Segmentation)6 Segment (org.dmg.pmml.mining.Segment)5 Model (org.dmg.pmml.Model)3 True (org.dmg.pmml.True)3 MiningSchema (org.dmg.pmml.MiningSchema)2 RegressionModel (org.dmg.pmml.regression.RegressionModel)2 IOException (java.io.IOException)1 ArrayList (java.util.ArrayList)1 JarFile (java.util.jar.JarFile)1 Manifest (java.util.jar.Manifest)1 TreeNode (ml.shifu.shifu.core.dtrain.dt.TreeNode)1 AbstractSpecifCreator (ml.shifu.shifu.core.pmml.builder.creator.AbstractSpecifCreator)1 MiningModelPmmlCreator (ml.shifu.shifu.core.pmml.builder.impl.MiningModelPmmlCreator)1 Application (org.dmg.pmml.Application)1 DataField (org.dmg.pmml.DataField)1 Header (org.dmg.pmml.Header)1 MiningField (org.dmg.pmml.MiningField)1 Output (org.dmg.pmml.Output)1 OutputField (org.dmg.pmml.OutputField)1