Use of org.dmg.pmml.mining.Segmentation in project pyramid by cheng-li.
Class MiningModelUtil, method createModelChain.
/**
 * Wraps the given models into a single {@code MiningModel} whose segmentation uses
 * {@code MODEL_CHAIN}.
 *
 * <p>The mining function and the (simplified) math context of the result are taken from the
 * last model in the list; the mining schema is created from the label of {@code schema}.
 *
 * @param models the models to chain, in segment order; must not be empty
 * @param schema the schema whose label is used to build the mining schema
 * @return a new mining model wrapping {@code models}
 * @throws IllegalArgumentException if {@code models} is empty
 */
public static MiningModel createModelChain(List<? extends Model> models, Schema schema) {
    if (models.isEmpty()) {
        throw new IllegalArgumentException("Expected at least one model to chain, got none");
    }
    Segmentation segmentation = createSegmentation(Segmentation.MultipleModelMethod.MODEL_CHAIN, models);
    Model lastModel = Iterables.getLast(models);
    return new MiningModel(lastModel.getMiningFunction(), ModelUtil.createMiningSchema(schema.getLabel()))
        .setMathContext(ModelUtil.simplifyMathContext(lastModel.getMathContext()))
        .setSegmentation(segmentation);
}
Use of org.dmg.pmml.mining.Segmentation in project pyramid by cheng-li.
Class MiningModelUtil, method createSegmentation.
/**
 * Builds a segmentation containing one always-true segment per model.
 *
 * <p>Segments receive 1-based string ids in list order. A weight is written to a segment only
 * when one was supplied for that model and it is not equal to one; otherwise the segment's
 * weight is left unset.
 *
 * @param multipleModelMethod how the segment results are combined
 * @param models the segment models, in order
 * @param weights optional per-model weights (may be {@code null}); when present it must have
 *     the same size as {@code models}, and individual entries may be {@code null}
 * @return the new segmentation
 * @throws IllegalArgumentException if {@code weights} is non-null and its size differs from
 *     the number of models
 */
public static Segmentation createSegmentation(Segmentation.MultipleModelMethod multipleModelMethod, List<? extends Model> models, List<? extends Number> weights) {
    if (weights != null && models.size() != weights.size()) {
        throw new IllegalArgumentException(
            "Expected " + models.size() + " weights (one per model), got " + weights.size());
    }
    List<Segment> segments = new ArrayList<>(models.size());
    for (int i = 0; i < models.size(); i++) {
        Model model = models.get(i);
        Number weight = (weights != null ? weights.get(i) : null);
        Segment segment = new Segment()
            .setId(String.valueOf(i + 1))
            .setPredicate(new True())
            .setModel(model);
        // Weights equal to one are not written (presumably because 1 is PMML's default
        // segment weight).
        if (weight != null && !ValueUtil.isOne(weight)) {
            segment.setWeight(ValueUtil.asDouble(weight));
        }
        segments.add(segment);
    }
    return new Segmentation(multipleModelMethod, segments);
}
Use of org.dmg.pmml.mining.Segmentation in project shifu by ShifuML.
Class PMMLTranslator, method build.
/**
 * Builds a PMML document for the given models.
 *
 * <p>When {@code isOutBaggingToOne} is set, every model in {@code basicMLs} becomes one
 * always-true segment of a single regression {@code MiningModel} whose segments are combined
 * with {@code AVERAGE}. Otherwise only the first model in the list is translated; any further
 * models are ignored.
 *
 * @param basicMLs the models to translate; must be non-null and non-empty
 * @return the generated PMML document
 * @throws IllegalArgumentException if {@code basicMLs} is null or empty
 */
public PMML build(List<BasicML> basicMLs) {
    if (basicMLs == null || basicMLs.isEmpty()) {
        throw new IllegalArgumentException("Input ml model list is empty.");
    }
    PMML pmml = new PMML();
    pmml.setHeader(createHeader());
    // create and set data dictionary for all bagging models
    pmml.setDataDictionary(this.dataDictionaryCreator.build(null));
    if (isOutBaggingToOne) {
        pmml.addModels(createAveragedMiningModel(basicMLs));
    } else {
        BasicML basicML = basicMLs.get(0);
        Model model = createModelElement(basicML);
        this.specifCreator.build(basicML, model);
        pmml.addModels(model);
    }
    return pmml;
}

/** Builds the PMML header: the copyright notice plus the "shifu" application entry. */
private Header createHeader() {
    Header header = new Header();
    header.setCopyright(" Copyright [2013-2018] PayPal Software Foundation\n" + "\n" + " Licensed under the Apache License, Version 2.0 (the \"License\");\n" + " you may not use this file except in compliance with the License.\n" + " You may obtain a copy of the License at\n" + "\n" + " http://www.apache.org/licenses/LICENSE-2.0\n" + "\n" + " Unless required by applicable law or agreed to in writing, software\n" + " distributed under the License is distributed on an \"AS IS\" BASIS,\n" + " WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" + " See the License for the specific language governing permissions and\n" + " limitations under the License.\n");
    Application application = new Application();
    header.setApplication(application);
    application.setName("shifu");
    application.setVersion(readShifuVersion());
    return header;
}

/**
 * Reads the "version" main attribute from the manifest of the jar containing
 * {@code TreeEnsemblePMMLTranslator}. Best effort: logs a warning and returns {@code null}
 * if the version cannot be determined.
 */
private String readShifuVersion() {
    String jarPath = JarManager.findContainingJar(TreeEnsemblePMMLTranslator.class);
    if (jarPath == null) {
        LOG.warn("Cannot determine shifu version: containing jar not found for "
            + TreeEnsemblePMMLTranslator.class.getName());
        return null;
    }
    // try-with-resources closes the jar; a failure in close() is handled by the catch below.
    try (JarFile jar = new JarFile(jarPath)) {
        Manifest manifest = jar.getManifest();
        if (manifest == null) {
            LOG.warn("Cannot determine shifu version: jar " + jarPath + " has no manifest");
            return null;
        }
        return manifest.getMainAttributes().getValue("version");
    } catch (IOException | RuntimeException e) {
        // The version is informational only; never fail PMML generation because of it.
        LOG.warn("Cannot read shifu version from jar " + jarPath, e);
        return null;
    }
}

/** Builds the regression mining model that averages one always-true segment per input model. */
private MiningModel createAveragedMiningModel(List<BasicML> basicMLs) {
    MiningModel miningModel = new MiningModel();
    miningModel.setMiningSchema(this.miningSchemaCreator.build(null));
    miningModel.setMiningFunction(MiningFunction.fromValue("regression"));
    // NOTE(review): assumes modelCreator is an NNPmmlModelCreator; any other creator fails
    // here with a ClassCastException.
    miningModel.setTargets(((NNPmmlModelCreator) this.modelCreator).createTargets());
    AbstractSpecifCreator miningModelCreator = new MiningModelPmmlCreator(this.specifCreator.getModelConfig(), this.specifCreator.getColumnConfigList());
    miningModelCreator.build(null, miningModel);
    Segmentation seg = new Segmentation();
    miningModel.setSegmentation(seg);
    seg.setMultipleModelMethod(MultipleModelMethod.AVERAGE);
    List<Segment> segments = seg.getSegments();
    int idCount = 0;
    for (BasicML basicML : basicMLs) {
        Model tmpModel = createModelElement(basicML);
        this.specifCreator.build(basicML, tmpModel, idCount);
        Segment segment = new Segment();
        // The "Segement" spelling is part of the emitted ids; kept unchanged for compatibility.
        segment.setId("Segement" + String.valueOf(idCount));
        segment.setPredicate(new True());
        segment.setModel(tmpModel);
        segments.add(segment);
        idCount++;
    }
    return miningModel;
}

/** Translates one model and attaches its mining schema, model stats and local transformations. */
private Model createModelElement(BasicML basicML) {
    Model model = this.modelCreator.build(basicML);
    model.setMiningSchema(this.miningSchemaCreator.build(basicML));
    model.setModelStats(this.modelStatsCreator.build(basicML));
    model.setLocalTransformations(this.localTransformationsCreator.build(basicML));
    return model;
}
Use of org.dmg.pmml.mining.Segmentation in project jpmml-sparkml by jpmml.
Class ModelConverter, method getLastModel.
/**
 * Returns the last model of a model chain.
 *
 * <p>If {@code model} is a {@code MiningModel} whose segmentation uses {@code MODEL_CHAIN}
 * and has at least one segment, the model of the last segment is returned (it is not
 * unwrapped further). In every other case {@code model} itself is returned.
 *
 * @param model the model to inspect
 * @return the last chained model, or {@code model} when there is no chain to unwrap
 */
protected org.dmg.pmml.Model getLastModel(org.dmg.pmml.Model model) {
    if (!(model instanceof MiningModel)) {
        return model;
    }
    Segmentation segmentation = ((MiningModel) model).getSegmentation();
    // equals() (rather than ==) dereferences the method, matching the NullPointerException
    // a switch statement throws when no multiple-model method is set.
    boolean isModelChain = segmentation.getMultipleModelMethod().equals(MultipleModelMethod.MODEL_CHAIN);
    if (!isModelChain) {
        return model;
    }
    List<Segment> chain = segmentation.getSegments();
    if (chain.isEmpty()) {
        return model;
    }
    return chain.get(chain.size() - 1).getModel();
}
Use of org.dmg.pmml.mining.Segmentation in project shifu by ShifuML.
Class TreeEnsemblePmmlCreator, method convert.
/**
 * Converts a single-bag tree ensemble into a PMML {@code MiningModel}.
 *
 * <p>Each tree of the single bag becomes one always-true segment holding a PMML
 * {@code TreeModel} named by its 0-based index. Segments are combined with
 * {@code weightedAverage}, weighted by the bag's tree weights. The mining function is
 * classification or regression according to {@code treeModel.isClassification()}.
 *
 * @param treeModel the tree model to convert; must contain exactly one bag of trees
 * @return the generated mining model
 * @throws UnsupportedOperationException if {@code treeModel} does not contain exactly one bag
 */
public MiningModel convert(IndependentTreeModel treeModel) {
    // Only a one-element trees list (a single bag) is supported; reject anything else before
    // building any PMML elements.
    if (treeModel.getTrees().size() != 1) {
        throw new UnsupportedOperationException("Bagging model cannot be supported in PMML generation.");
    }
    MiningModel gbt = new MiningModel();
    MiningSchema miningSchema = new TreeModelMiningSchemaCreator(this.modelConfig, this.columnConfigList).build(null);
    gbt.setMiningSchema(miningSchema);
    if (treeModel.isClassification()) {
        gbt.setMiningFunction(MiningFunction.fromValue("classification"));
    } else {
        gbt.setMiningFunction(MiningFunction.fromValue("regression"));
    }
    gbt.setTargets(createTargets(this.modelConfig));
    Segmentation seg = new Segmentation();
    gbt.setSegmentation(seg);
    seg.setMultipleModelMethod(MultipleModelMethod.fromValue("weightedAverage"));
    List<Segment> list = seg.getSegments();
    int idCount = 0;
    for (TreeNode tn : treeModel.getTrees().get(0)) {
        TreeNodePmmlElementCreator tnec = new TreeNodePmmlElementCreator(this.modelConfig, this.columnConfigList, treeModel);
        org.dmg.pmml.tree.Node root = tnec.convert(tn.getNode());
        TreeModelPmmlElementCreator tmec = new TreeModelPmmlElementCreator(this.modelConfig, this.columnConfigList);
        org.dmg.pmml.tree.TreeModel tm = tmec.convert(treeModel, root);
        tm.setModelName(String.valueOf(idCount));
        Segment segment = new Segment();
        // With exactly one bag enforced above, the former GBDT scaling by the number of bags
        // was always a multiplication by one, so GBDT and RF trees use the raw weight.
        segment.setWeight(treeModel.getWeights().get(0).get(idCount));
        // The "Segement" spelling is part of the emitted ids; kept unchanged for compatibility.
        segment.setId("Segement" + String.valueOf(idCount));
        idCount++;
        segment.setPredicate(new True());
        segment.setModel(tm);
        list.add(segment);
    }
    return gbt;
}
Aggregations