Search in sources :

Example 1 with MiningField

use of org.dmg.pmml.MiningField in project shifu by ShifuML.

the class PMMLAdapterCommonUtil method getDicFieldIDViaType.

/**
 * Looks up the data-dictionary positions of the mining fields that carry the
 * requested usage type.
 * <p>
 * Only the mining schema of the first model in the PMML document is
 * inspected. Positions are zero-based and follow the iteration order of the
 * data dictionary.
 *
 * @param pmml
 *            the PMML document supplying the data dictionary and the models
 * @param type
 *            the usage type a mining field must have to be selected
 * @return the data-dictionary positions of the selected mining fields, in
 *         the order the fields appear in the mining schema
 */
public static int[] getDicFieldIDViaType(PMML pmml, UsageType type) {
    // Position of every data field, keyed by the field's name.
    HashMap<String, Integer> positionByName = new HashMap<String, Integer>();
    int position = 0;
    for (DataField dataField : pmml.getDataDictionary().getDataFields()) {
        positionByName.put(dataField.getName().getValue(), position);
        position++;
    }

    List<Integer> matches = new ArrayList<Integer>();
    for (MiningField miningField : pmml.getModels().get(0).getMiningSchema().getMiningFields()) {
        if (miningField.getUsageType() != type) {
            continue;
        }
        String fieldName = miningField.getName().getValue();
        matches.add(positionByName.get(fieldName));
    }
    return Ints.toArray(matches);
}
Also used : MiningField(org.dmg.pmml.MiningField) DataField(org.dmg.pmml.DataField) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList)

Example 2 with MiningField

use of org.dmg.pmml.MiningField in project shifu by ShifuML.

the class MiningSchemaCreator method createMiningField.

/**
 * Builds a mining field with the given name, operational type and usage
 * type. Invalid values are always configured to be treated as missing.
 *
 * @param name
 *            the field name, converted with {@code FieldName.create}
 * @param opType
 *            the operational type to assign
 * @param fieldUsageType
 *            the usage type to assign
 * @return the newly created mining field
 */
private MiningField createMiningField(String name, OpType opType, UsageType fieldUsageType) {
    final MiningField field = new MiningField();
    final FieldName fieldName = FieldName.create(name);
    field.setName(fieldName);
    field.setOpType(opType);
    field.setUsageType(fieldUsageType);
    field.setInvalidValueTreatment(InvalidValueTreatmentMethod.AS_MISSING);
    return field;
}
Also used : MiningField(org.dmg.pmml.MiningField)

Example 3 with MiningField

use of org.dmg.pmml.MiningField in project jpmml-sparkml by jpmml.

the class ConverterUtil method toPMML.

/**
 * Converts a Spark ML pipeline model into a PMML document.
 * <p>
 * Every transformer of the pipeline is mapped to a converter. Feature
 * converters register their features with the encoder, model converters
 * register a model. A single model becomes the root model as is; several
 * models are combined into a model chain whose mining schema holds the
 * predicted and target fields of all the individual models.
 *
 * @param schema
 *            the schema handed to the {@code SparkMLEncoder}
 * @param pipelineModel
 *            the pipeline whose transformers are converted
 * @return the PMML document produced by the encoder for the root model
 * @throws IllegalArgumentException
 *             if a converter is neither a feature converter nor a model
 *             converter, or if the pipeline yields no model at all
 */
public static PMML toPMML(StructType schema, PipelineModel pipelineModel) {
    checkVersion();
    SparkMLEncoder encoder = new SparkMLEncoder(schema);
    List<org.dmg.pmml.Model> models = new ArrayList<>();

    for (Transformer transformer : getTransformers(pipelineModel)) {
        TransformerConverter<?> converter = ConverterUtil.createConverter(transformer);
        if (converter instanceof FeatureConverter) {
            ((FeatureConverter<?>) converter).registerFeatures(encoder);
            continue;
        }
        if (converter instanceof ModelConverter) {
            models.add(((ModelConverter<?>) converter).registerModel(encoder));
            continue;
        }
        throw new IllegalArgumentException("Expected a " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + " instance, got " + converter);
    }

    if (models.isEmpty()) {
        throw new IllegalArgumentException("Expected a pipeline with one or more models, got a pipeline with zero models");
    }
    if (models.size() == 1) {
        org.dmg.pmml.Model onlyModel = Iterables.getOnlyElement(models);
        return encoder.encodePMML(onlyModel);
    }

    // Several models: gather the predicted/target fields of each one for the chain's own schema.
    List<MiningField> chainTargets = new ArrayList<>();
    for (org.dmg.pmml.Model model : models) {
        for (MiningField candidate : model.getMiningSchema().getMiningFields()) {
            switch (candidate.getUsageType()) {
                case PREDICTED:
                case TARGET:
                    chainTargets.add(candidate);
                    break;
                default:
                    break;
            }
        }
    }
    MiningSchema chainSchema = new MiningSchema(chainTargets);
    org.dmg.pmml.Model chainModel = MiningModelUtil.createModelChain(models, new Schema(null, Collections.<Feature>emptyList())).setMiningSchema(chainSchema);
    return encoder.encodePMML(chainModel);
}
Also used : MiningField(org.dmg.pmml.MiningField) Transformer(org.apache.spark.ml.Transformer) MiningSchema(org.dmg.pmml.MiningSchema) Schema(org.jpmml.converter.Schema) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) MiningSchema(org.dmg.pmml.MiningSchema) MiningModel(org.dmg.pmml.mining.MiningModel) MiningModel(org.dmg.pmml.mining.MiningModel) PipelineModel(org.apache.spark.ml.PipelineModel) TrainValidationSplitModel(org.apache.spark.ml.tuning.TrainValidationSplitModel) CrossValidatorModel(org.apache.spark.ml.tuning.CrossValidatorModel) PMML(org.dmg.pmml.PMML) ArrayList(java.util.ArrayList) List(java.util.List)

Example 4 with MiningField

use of org.dmg.pmml.MiningField in project shifu by ShifuML.

the class PMMLLRModelBuilder method adaptMLModelToPMML.

/**
 * Fills a PMML regression model from a trained logistic regression.
 * <p>
 * The model is switched to LOGIT normalization and the REGRESSION mining
 * function, and receives one regression table: its intercept is the bias of
 * {@code lr}, and it holds one numeric predictor per ACTIVE mining field.
 * Each predictor is named after the last derived field reached by following
 * the local transformations from the mining field, and takes the next
 * weight of {@code lr} in mining-schema order.
 *
 * @param lr
 *            the source of the bias and the weights
 * @param pmmlModel
 *            the regression model to populate; its local transformations and
 *            mining schema are read
 * @return the same {@code pmmlModel} instance, with the regression table added
 */
public RegressionModel adaptMLModelToPMML(ml.shifu.shifu.core.LR lr, RegressionModel pmmlModel) {
    pmmlModel.setNormalizationMethod(NormalizationMethod.LOGIT);
    pmmlModel.setMiningFunction(MiningFunction.REGRESSION);

    RegressionTable regressionTable = new RegressionTable();
    regressionTable.setIntercept(lr.getBias());

    // Maps the input field of each supported transformation to the derived field it produces.
    HashMap<FieldName, FieldName> derivedBySource = new HashMap<FieldName, FieldName>();
    List<DerivedField> derivedFields = pmmlModel.getLocalTransformations().getDerivedFields();
    for (DerivedField derivedField : derivedFields) {
        org.dmg.pmml.Expression expression = derivedField.getExpression();
        if (expression instanceof NormContinuous) {
            // z-scale normalization on numerical variables
            NormContinuous normContinuous = (NormContinuous) expression;
            derivedBySource.put(normContinuous.getField(), derivedField.getName());
        } else if (expression instanceof MapValues) {
            // bin map on categorical variables
            MapValues mapValues = (MapValues) expression;
            derivedBySource.put(mapValues.getFieldColumnPairs().get(0).getField(), derivedField.getName());
        } else if (expression instanceof Discretize) {
            Discretize discretize = (Discretize) expression;
            derivedBySource.put(discretize.getField(), derivedField.getName());
        }
    }

    int weightIndex = 0;
    for (MiningField miningField : pmmlModel.getMiningSchema().getMiningFields()) {
        if (miningField.getUsageType() == UsageType.ACTIVE) {
            // Follow the transformation chain to the final derived field name.
            FieldName predictorName = miningField.getName();
            while (derivedBySource.containsKey(predictorName)) {
                predictorName = derivedBySource.get(predictorName);
            }
            NumericPredictor predictor = new NumericPredictor();
            predictor.setName(predictorName);
            predictor.setCoefficient(lr.getWeights()[weightIndex]);
            weightIndex++;
            regressionTable.addNumericPredictors(predictor);
        }
    }

    pmmlModel.addRegressionTables(regressionTable);
    return pmmlModel;
}
Also used : NormContinuous(org.dmg.pmml.NormContinuous) MiningField(org.dmg.pmml.MiningField) HashMap(java.util.HashMap) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) RegressionTable(org.dmg.pmml.regression.RegressionTable) LocalTransformations(org.dmg.pmml.LocalTransformations) MapValues(org.dmg.pmml.MapValues) Discretize(org.dmg.pmml.Discretize) DerivedField(org.dmg.pmml.DerivedField) FieldName(org.dmg.pmml.FieldName)

Example 5 with MiningField

use of org.dmg.pmml.MiningField in project shifu by ShifuML.

the class TreeModelMiningSchemaCreator method build.

/**
 * Builds the mining schema from the column configuration list.
 * <p>
 * A mining field is added for every column that is final-selected or is a
 * target; all other columns are skipped. Numerical columns use the column
 * mean as missing-value replacement, other columns use the empty string.
 * Final-selected columns become ACTIVE fields, remaining target columns
 * become TARGET fields.
 *
 * @param basicML
 *            not read by this implementation
 * @return the populated mining schema
 */
@Override
public MiningSchema build(BasicML basicML) {
    MiningSchema schema = new MiningSchema();
    for (ColumnConfig column : columnConfigList) {
        if (!column.isFinalSelect() && !column.isTarget()) {
            continue;
        }

        MiningField field = new MiningField();
        // TODO, how to support segment variable in tree model, here should be changed
        String simpleName = NormalUtils.getSimpleColumnName(column.getColumnName());
        field.setName(FieldName.create(simpleName));
        field.setOpType(getOptype(column));
        field.setMissingValueReplacement(column.isNumerical() ? String.valueOf(column.getColumnStats().getMean()) : "");

        // Final-select takes precedence when a column is flagged as both.
        if (column.isFinalSelect()) {
            field.setUsageType(UsageType.ACTIVE);
        } else if (column.isTarget()) {
            field.setUsageType(UsageType.TARGET);
        }
        schema.addMiningFields(field);
    }
    return schema;
}
Also used : MiningField(org.dmg.pmml.MiningField) MiningSchema(org.dmg.pmml.MiningSchema) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig)

Aggregations

MiningField (org.dmg.pmml.MiningField)8 ArrayList (java.util.ArrayList)4 HashMap (java.util.HashMap)3 List (java.util.List)3 FieldName (org.dmg.pmml.FieldName)3 MiningSchema (org.dmg.pmml.MiningSchema)3 ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)2 DerivedField (org.dmg.pmml.DerivedField)2 Discretize (org.dmg.pmml.Discretize)2 MapValues (org.dmg.pmml.MapValues)2 NormContinuous (org.dmg.pmml.NormContinuous)2 BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)1 PipelineModel (org.apache.spark.ml.PipelineModel)1 Transformer (org.apache.spark.ml.Transformer)1 CrossValidatorModel (org.apache.spark.ml.tuning.CrossValidatorModel)1 TrainValidationSplitModel (org.apache.spark.ml.tuning.TrainValidationSplitModel)1 DataField (org.dmg.pmml.DataField)1 FieldRef (org.dmg.pmml.FieldRef)1 LocalTransformations (org.dmg.pmml.LocalTransformations)1 PMML (org.dmg.pmml.PMML)1