Use of org.dmg.pmml.MiningField in project shifu by ShifuML.
In class PMMLAdapterCommonUtil, method getDicFieldIDViaType:
/**
 * Based on the usage type, get the column indexes (positions in the PMML
 * DataDictionary) of the corresponding mining fields in the input data set.
 *
 * @param pmml
 *            the pmml model; only the first model in {@code pmml.getModels()}
 *            is inspected
 * @param type
 *            the usage type to select
 * @return data-dictionary indexes of the mining fields whose usage type equals
 *         {@code type}, in mining-schema order
 * @throws IllegalArgumentException
 *             if a selected mining field has no DataField with the same name in
 *             the data dictionary
 */
public static int[] getDicFieldIDViaType(PMML pmml, UsageType type) {
    // Position of every DataField in the dictionary, keyed by field name.
    HashMap<String, Integer> dMap = new HashMap<String, Integer>();
    int index = 0;
    for (DataField dField : pmml.getDataDictionary().getDataFields()) {
        dMap.put(dField.getName().getValue(), index++);
    }

    List<Integer> activeFields = new ArrayList<Integer>();
    for (MiningField mField : pmml.getModels().get(0).getMiningSchema().getMiningFields()) {
        if (mField.getUsageType() != type) {
            continue;
        }
        String fieldName = mField.getName().getValue();
        Integer dicIndex = dMap.get(fieldName);
        if (dicIndex == null) {
            // Without this check the null would be unboxed inside Ints.toArray and
            // surface as a NullPointerException with no hint of which field is missing.
            throw new IllegalArgumentException("Mining field '" + fieldName
                    + "' has no matching DataField in the DataDictionary");
        }
        activeFields.add(dicIndex);
    }
    return Ints.toArray(activeFields);
}
Use of org.dmg.pmml.MiningField in project shifu by ShifuML.
In class MiningSchemaCreator, method createMiningField:
/**
 * Creates a mining field with the given name, operational type and usage type.
 * Invalid values of the field are configured to be treated as missing
 * ({@code InvalidValueTreatmentMethod.AS_MISSING}).
 *
 * @param name the field name
 * @param opType the operational type of the field
 * @param fieldUsageType the usage type of the field
 * @return the newly created, configured mining field
 */
private MiningField createMiningField(String name, OpType opType, UsageType fieldUsageType) {
    final MiningField field = new MiningField();
    final FieldName fieldName = FieldName.create(name);

    field.setName(fieldName);
    field.setOpType(opType);
    field.setUsageType(fieldUsageType);
    // Invalid values are treated as missing rather than rejected.
    field.setInvalidValueTreatment(InvalidValueTreatmentMethod.AS_MISSING);

    return field;
}
Use of org.dmg.pmml.MiningField in project jpmml-sparkml by jpmml.
In class ConverterUtil, method toPMML:
/**
 * Converts a fitted Spark ML pipeline into a PMML document.
 * <p>
 * Each pipeline stage is converted in order: feature converters register their
 * features with the encoder, and model converters register and return a PMML
 * model. A single model becomes the root model directly. Several models are
 * chained into a mining model whose mining schema consists of all PREDICTED and
 * TARGET mining fields of the chained models.
 *
 * @param schema the input schema of the pipeline
 * @param pipelineModel the fitted pipeline
 * @return the encoded PMML document
 * @throws IllegalArgumentException if a stage converter is neither a feature nor a
 *         model converter, or if the pipeline contains no model at all
 */
public static PMML toPMML(StructType schema, PipelineModel pipelineModel) {
    checkVersion();

    SparkMLEncoder encoder = new SparkMLEncoder(schema);
    List<org.dmg.pmml.Model> registeredModels = new ArrayList<>();

    for (Transformer stage : getTransformers(pipelineModel)) {
        TransformerConverter<?> stageConverter = ConverterUtil.createConverter(stage);

        if (stageConverter instanceof FeatureConverter) {
            ((FeatureConverter<?>) stageConverter).registerFeatures(encoder);
            continue;
        }
        if (stageConverter instanceof ModelConverter) {
            registeredModels.add(((ModelConverter<?>) stageConverter).registerModel(encoder));
            continue;
        }
        throw new IllegalArgumentException("Expected a " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + " instance, got " + stageConverter);
    }

    if (registeredModels.isEmpty()) {
        throw new IllegalArgumentException("Expected a pipeline with one or more models, got a pipeline with zero models");
    }

    org.dmg.pmml.Model rootModel;
    if (registeredModels.size() == 1) {
        rootModel = Iterables.getOnlyElement(registeredModels);
    } else {
        // The chain exposes only the target/predicted fields of its member models.
        MiningSchema chainSchema = new MiningSchema(gatherTargetMiningFields(registeredModels));
        rootModel = MiningModelUtil.createModelChain(registeredModels, new Schema(null, Collections.<Feature>emptyList())).setMiningSchema(chainSchema);
    }

    return encoder.encodePMML(rootModel);
}

/**
 * Collects, in model order, every mining field whose usage type is PREDICTED or
 * TARGET. A {@code switch} is used deliberately, so that a null usage type fails
 * exactly as a switch over it would.
 */
private static List<MiningField> gatherTargetMiningFields(List<org.dmg.pmml.Model> models) {
    List<MiningField> collected = new ArrayList<>();
    for (org.dmg.pmml.Model model : models) {
        for (MiningField field : model.getMiningSchema().getMiningFields()) {
            switch (field.getUsageType()) {
                case PREDICTED:
                case TARGET:
                    collected.add(field);
                    break;
                default:
                    break;
            }
        }
    }
    return collected;
}
Use of org.dmg.pmml.MiningField in project shifu by ShifuML.
In class PMMLLRModelBuilder, method adaptMLModelToPMML:
/**
 * Populates {@code pmmlModel} with a single regression table built from the
 * trained LR model.
 * <p>
 * The normalization method is set to LOGIT and the mining function to
 * REGRESSION. The intercept is {@code lr.getBias()}. One NumericPredictor is
 * added per ACTIVE mining field, in mining-schema order, and the coefficients are
 * taken sequentially from {@code lr.getWeights()}. Each predictor is named after
 * the last field reached by following the local transformations (NormContinuous,
 * MapValues, Discretize) that start from the mining field.
 *
 * @param lr the trained LR model supplying bias and weights
 * @param pmmlModel the regression model to populate; modified in place
 * @return the same {@code pmmlModel} instance
 */
public RegressionModel adaptMLModelToPMML(ml.shifu.shifu.core.LR lr, RegressionModel pmmlModel) {
    pmmlModel.setNormalizationMethod(NormalizationMethod.LOGIT);
    pmmlModel.setMiningFunction(MiningFunction.REGRESSION);
    RegressionTable table = new RegressionTable();
    table.setIntercept(lr.getBias());

    // Source field name -> name of the derived field that transforms it.
    HashMap<FieldName, FieldName> miningTransformMap = new HashMap<FieldName, FieldName>();
    LocalTransformations lt = pmmlModel.getLocalTransformations();
    // LocalTransformations is optional; a model without it simply has no chains.
    if (lt != null) {
        for (DerivedField dField : lt.getDerivedFields()) {
            if (dField.getExpression() instanceof NormContinuous) {
                // z-scale normalization on numerical variables
                miningTransformMap.put(((NormContinuous) dField.getExpression()).getField(), dField.getName());
            } else if (dField.getExpression() instanceof MapValues) {
                // bin map on categorical variables; the first field/column pair names the source
                List<FieldColumnPair> pairs = ((MapValues) dField.getExpression()).getFieldColumnPairs();
                if (!pairs.isEmpty()) {
                    miningTransformMap.put(pairs.get(0).getField(), dField.getName());
                }
            } else if (dField.getExpression() instanceof Discretize) {
                miningTransformMap.put(((Discretize) dField.getExpression()).getField(), dField.getName());
            }
        }
    }

    int index = 0;
    for (MiningField mField : pmmlModel.getMiningSchema().getMiningFields()) {
        if (mField.getUsageType() != UsageType.ACTIVE) {
            continue;
        }
        // Follow source -> derived links to the final derived field. The visited set
        // stops the walk if the transformations form a cycle (e.g. a derived field
        // named like its own source), which would otherwise loop forever.
        FieldName fName = mField.getName();
        java.util.Set<FieldName> visited = new java.util.HashSet<FieldName>();
        while (miningTransformMap.containsKey(fName) && visited.add(fName)) {
            fName = miningTransformMap.get(fName);
        }
        NumericPredictor np = new NumericPredictor();
        np.setName(fName);
        np.setCoefficient(lr.getWeights()[index++]);
        table.addNumericPredictors(np);
    }
    pmmlModel.addRegressionTables(table);
    return pmmlModel;
}
Use of org.dmg.pmml.MiningField in project shifu by ShifuML.
In class TreeModelMiningSchemaCreator, method build:
/**
 * Builds the mining schema for a tree model from the column configuration.
 * <p>
 * Every column that is final-selected or is the target becomes a mining field.
 * Final-selected columns are ACTIVE and the target column is TARGET. Numerical
 * columns use the string form of their column-stats mean as missing-value
 * replacement; all other columns use the empty string.
 *
 * @param basicML the trained model (not consulted by this implementation)
 * @return the populated mining schema
 */
@Override
public MiningSchema build(BasicML basicML) {
    MiningSchema schema = new MiningSchema();
    for (ColumnConfig config : columnConfigList) {
        // Skip columns that are neither selected features nor the target.
        if (!config.isFinalSelect() && !config.isTarget()) {
            continue;
        }

        MiningField field = new MiningField();
        // TODO, how to support segment variable in tree model, here should be changed
        field.setName(FieldName.create(NormalUtils.getSimpleColumnName(config.getColumnName())));
        field.setOpType(getOptype(config));

        String replacement = config.isNumerical()
                ? String.valueOf(config.getColumnStats().getMean())
                : "";
        field.setMissingValueReplacement(replacement);

        if (config.isFinalSelect()) {
            field.setUsageType(UsageType.ACTIVE);
        } else if (config.isTarget()) {
            field.setUsageType(UsageType.TARGET);
        }

        schema.addMiningFields(field);
    }
    return schema;
}
Aggregations