Use of org.dmg.pmml.Model in project shifu by ShifuML.
In class PMMLConstructorFactory, method produce:
/**
 * Creates the PMML translator matching the training algorithm configured in {@code modelConfig}.
 *
 * <p>NN and LR models get a {@link PMMLTranslator} assembled from model, data-dictionary,
 * mining-schema, model-stats, local-transformation and specification creators; the
 * local-transformation creator is selected from the configured normalize type. GBT and RF
 * models are handled separately by a {@link TreeEnsemblePMMLTranslator}, which receives neither
 * {@code isConcise} nor {@code isOutBaggingToOne}.
 *
 * @param modelConfig       model configuration; its train algorithm and normalize type drive the choice
 * @param columnConfigList  column configurations handed to every creator
 * @param isConcise         forwarded to the NN/LR creators
 * @param isOutBaggingToOne forwarded to the resulting {@link PMMLTranslator} (NN/LR only)
 * @return the translator for the configured algorithm
 * @throws RuntimeException     if the algorithm is not NN, LR, GBT or RF
 * @throws NullPointerException if an NN/LR model has no normalize type configured
 */
public static PMMLTranslator produce(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, boolean isConcise, boolean isOutBaggingToOne) {
    // Read once; every branch below matches this name case-insensitively.
    String algorithm = modelConfig.getTrain().getAlgorithm();

    AbstractPmmlElementCreator<Model> modelCreator;
    AbstractSpecifCreator specifCreator;
    if (ModelTrainConf.ALGORITHM.NN.name().equalsIgnoreCase(algorithm)) {
        modelCreator = new NNPmmlModelCreator(modelConfig, columnConfigList, isConcise);
        specifCreator = new NNSpecifCreator(modelConfig, columnConfigList);
    } else if (ModelTrainConf.ALGORITHM.LR.name().equalsIgnoreCase(algorithm)) {
        modelCreator = new RegressionPmmlModelCreator(modelConfig, columnConfigList, isConcise);
        specifCreator = new RegressionSpecifCreator(modelConfig, columnConfigList);
    } else if (ModelTrainConf.ALGORITHM.GBT.name().equalsIgnoreCase(algorithm)
            || ModelTrainConf.ALGORITHM.RF.name().equalsIgnoreCase(algorithm)) {
        // Tree ensembles use a dedicated translator; isConcise and isOutBaggingToOne are not applied here.
        TreeEnsemblePmmlCreator gbtmodelCreator = new TreeEnsemblePmmlCreator(modelConfig, columnConfigList);
        AbstractPmmlElementCreator<DataDictionary> dataDictionaryCreator = new DataDictionaryCreator(modelConfig, columnConfigList);
        AbstractPmmlElementCreator<MiningSchema> miningSchemaCreator = new TreeModelMiningSchemaCreator(modelConfig, columnConfigList);
        return new TreeEnsemblePMMLTranslator(gbtmodelCreator, dataDictionaryCreator, miningSchemaCreator);
    } else {
        throw new RuntimeException("Model not supported: " + algorithm);
    }

    AbstractPmmlElementCreator<DataDictionary> dataDictionaryCreator = new DataDictionaryCreator(modelConfig, columnConfigList, isConcise);
    AbstractPmmlElementCreator<MiningSchema> miningSchemaCreator = new MiningSchemaCreator(modelConfig, columnConfigList, isConcise);
    AbstractPmmlElementCreator<ModelStats> modelStatsCreator = new ModelStatsCreator(modelConfig, columnConfigList, isConcise);

    // An enum switch compares constants by identity, consistently for every case. Like the previous
    // normType.equals(...) check, it throws NullPointerException when no normalize type is configured.
    AbstractPmmlElementCreator<LocalTransformations> localTransformationsCreator;
    ModelNormalizeConf.NormType normType = modelConfig.getNormalizeType();
    switch (normType) {
        case WOE:
        case WEIGHT_WOE:
            localTransformationsCreator = new WoeLocalTransformCreator(modelConfig, columnConfigList, isConcise);
            break;
        case WOE_ZSCORE:
        case WOE_ZSCALE:
            localTransformationsCreator = new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, false);
            break;
        case WEIGHT_WOE_ZSCORE:
        case WEIGHT_WOE_ZSCALE:
            localTransformationsCreator = new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, true);
            break;
        case ZSCALE_ONEHOT:
            localTransformationsCreator = new ZscoreOneHotLocalTransformCreator(modelConfig, columnConfigList, isConcise);
            break;
        default:
            localTransformationsCreator = new ZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise);
            break;
    }
    return new PMMLTranslator(modelCreator, dataDictionaryCreator, miningSchemaCreator, modelStatsCreator, localTransformationsCreator, specifCreator, isOutBaggingToOne);
}
Use of org.dmg.pmml.Model in project shifu by ShifuML.
In class NNPmmlModelCreator, method build:
/**
 * Creates the PMML {@code NeuralNetwork} model element.
 *
 * <p>The mining function is always {@code REGRESSION}, and the targets come from
 * {@code createTargets()}. The {@code basicML} argument is not read by this method.
 *
 * @param basicML the trained network (unused here)
 * @return a new NeuralNetwork model with its mining function and targets set
 */
@Override
public Model build(BasicML basicML) {
    Model model = new NeuralNetwork();
    // Always REGRESSION. A NATIVE multi-class CLASSIFICATION branch used to sit here as
    // commented-out code and was never enabled.
    model.setMiningFunction(MiningFunction.REGRESSION);
    model.setTargets(createTargets());
    return model;
}
Use of org.dmg.pmml.Model in project shifu by ShifuML.
In class PMMLVerifySuit, method evalLRPmml:
/**
 * Scores every record of a delimited data file with the first model of a PMML document.
 *
 * <p>The output is UTF-8 and is written to {@code OutPath}: the {@code scoreName} header line,
 * then one integer score per input record.
 *
 * @param pmmlPath  path of the PMML document to load
 * @param DataPath  path of the input data file
 * @param OutPath   path of the output file (truncated if it exists, created otherwise)
 * @param sep       field separator used in the data file
 * @param scoreName header line written before the scores
 * @throws Exception if loading the PMML or data, or evaluating a record, fails
 */
@SuppressWarnings("unchecked")
private void evalLRPmml(String pmmlPath, String DataPath, String OutPath, String sep, String scoreName) throws Exception {
    PMML pmml = PMMLUtils.loadPMML(pmmlPath);
    Model m = pmml.getModels().get(0);
    ModelEvaluator<?> evaluator = ModelEvaluatorFactory.newInstance().newModelEvaluator(pmml, m);
    // Loop-invariant: the result field looked up in every evaluation result.
    FieldName finalResultField = new FieldName(NNSpecifCreator.FINAL_RESULT);
    // try-with-resources closes the writer even when loading or evaluating the data throws.
    // NOTE(review): PrintWriter suppresses write errors; consider checking writer.checkError().
    try (PrintWriter writer = new PrintWriter(OutPath, "UTF-8")) {
        writer.println(scoreName);
        List<Map<FieldName, FieldValue>> input = CsvUtil.load(evaluator, DataPath, sep);
        for (Map<FieldName, FieldValue> maps : input) {
            // Unchecked cast: assumes the evaluator yields Double values for this model's result fields.
            Map<FieldName, Double> regressionTerm = (Map<FieldName, Double>) evaluator.evaluate(maps);
            // intValue() truncates the score toward zero before it is written.
            writer.println(regressionTerm.get(finalResultField).intValue());
        }
    }
}
Use of org.dmg.pmml.Model in project jpmml-r by jpmml.
In class ModelConverter, method encode:
/**
 * Encodes the model for {@code schema}.
 *
 * <p>When this converter also implements {@link HasFeatureImportances} and reports a non-null,
 * non-empty importance map, each importance is registered on the encoded model through the
 * schema's {@link ModelEncoder}.
 *
 * @param schema the schema to encode against
 * @return the encoded model
 */
public Model encode(Schema schema) {
    Model result = encodeModel(schema);

    if (!(this instanceof HasFeatureImportances)) {
        return result;
    }

    FeatureImportanceMap importances = ((HasFeatureImportances) this).getFeatureImportances(schema);
    if (importances == null || importances.isEmpty()) {
        return result;
    }

    ModelEncoder modelEncoder = (ModelEncoder) schema.getEncoder();
    for (Map.Entry<Feature, Number> importance : importances.entrySet()) {
        modelEncoder.addFeatureImportance(result, importance.getKey(), importance.getValue());
    }
    return result;
}
Use of org.dmg.pmml.Model in project jpmml-r by jpmml.
In class CaretEnsembleConverter, method encodePMML:
/**
 * Encodes a caretEnsemble object as a PMML model chain.
 *
 * <p>Each entry of the R {@code models} list is encoded as a segment model. Each segment exposes a
 * single non-final output field named after the entry:
 * <ul>
 *   <li>the predicted value, for regression members;</li>
 *   <li>the probability of the second target category, for classification members.</li>
 * </ul>
 * The {@code ens_model} element is encoded last and appended as the final link of the chain,
 * presumably so that it can consume the member outputs.
 *
 * @param encoder the ensemble-level encoder; fields from every member encoder are added to it
 * @return the encoded PMML document
 * @throws IllegalArgumentException if a member has an unsupported label type or mining function
 */
@Override
public PMML encodePMML(RExpEncoder encoder) {
    RGenericVector caretEnsemble = getObject();
    RGenericVector models = caretEnsemble.getGenericElement("models");
    RGenericVector ensModel = caretEnsemble.getGenericElement("ens_model");
    RStringVector modelNames = models.names();
    List<Model> segmentationModels = new ArrayList<>();

    // Member models see an anonymized continuous label; categorical labels pass through unchanged.
    Function<Schema, Schema> segmentSchemaFunction = schema -> {
        Label label = schema.getLabel();
        if (label instanceof ContinuousLabel) {
            return schema.toAnonymousSchema();
        } else if (label instanceof CategoricalLabel) {
            // XXX: Ideally, the categorical target field should also be anonymized
            return schema;
        } else {
            throw new IllegalArgumentException("Unsupported label type: " + label);
        }
    };

    for (int i = 0; i < models.size(); i++) {
        RGenericVector model = models.getGenericValue(i);
        Conversion conversion = encodeTrainModel(model, segmentSchemaFunction);
        // Copy the fields defined by the member's own encoder into the ensemble-level encoder.
        RExpEncoder segmentEncoder = conversion.getEncoder();
        encoder.addFields(segmentEncoder);
        Schema segmentSchema = conversion.getSchema();
        Model segmentModel = conversion.getModel();
        String name = modelNames.getValue(i);
        OutputField outputField = createSegmentOutputField(name, segmentModel, segmentSchema);
        segmentModel.setOutput(new Output().addOutputFields(outputField));
        segmentationModels.add(segmentModel);
    }

    // The ensemble model is encoded without a segment schema function (null is passed).
    Conversion conversion = encodeTrainModel(ensModel, null);
    segmentationModels.add(conversion.getModel());
    MiningModel miningModel = MiningModelUtil.createModelChain(segmentationModels);
    return encoder.encodePMML(miningModel);
}

/**
 * Creates the non-final output field that exposes one member model's prediction under {@code name}.
 *
 * @param name          output field name (the member's name in the R {@code models} list)
 * @param segmentModel  the encoded member model
 * @param segmentSchema the schema the member was encoded with
 * @return the output field for the member
 * @throws IllegalArgumentException if the member is neither a regression nor a classification model
 */
private static OutputField createSegmentOutputField(String name, Model segmentModel, Schema segmentSchema) {
    MiningFunction miningFunction = segmentModel.requireMiningFunction();
    switch (miningFunction) {
        case REGRESSION: {
            return ModelUtil.createPredictedField(name, OpType.CONTINUOUS, DataType.DOUBLE).setFinalResult(Boolean.FALSE);
        }
        case CLASSIFICATION: {
            CategoricalLabel categoricalLabel = (CategoricalLabel) segmentSchema.getLabel();
            // Expects a two-category target (checked by SchemaUtil.checkSize).
            // The exposed value is the probability of the category at index 1.
            SchemaUtil.checkSize(2, categoricalLabel);
            return ModelUtil.createProbabilityField(name, DataType.DOUBLE, categoricalLabel.getValue(1)).setFinalResult(Boolean.FALSE);
        }
        default:
            throw new IllegalArgumentException("Unsupported mining function: " + miningFunction);
    }
}
Aggregations