Search in sources :

Example 36 with DerivedField

use of org.dmg.pmml.DerivedField in project shifu by ShifuML.

the class PMMLLRModelBuilder method adaptMLModelToPMML.

/**
 * Copies the bias and weights of an LR model into a PMML regression model.
 *
 * <p>The PMML model is flagged as a regression with LOGIT normalization and receives one
 * {@link RegressionTable}. The table's intercept is {@code lr.getBias()}; one
 * {@link NumericPredictor} is added for every mining field whose usage type is ACTIVE, in
 * mining-schema order, taking coefficients from {@code lr.getWeights()} sequentially.
 * Each predictor is named after the last derived field reachable from the mining field
 * through the model's local transformations, or after the mining field itself when no
 * recognized transformation reads it.
 *
 * @param lr        LR model supplying the bias and the weight array
 * @param pmmlModel PMML model whose mining schema and local transformations are read
 * @return the same {@code pmmlModel} instance, with the regression table added
 */
public RegressionModel adaptMLModelToPMML(ml.shifu.shifu.core.LR lr, RegressionModel pmmlModel) {
    pmmlModel.setNormalizationMethod(NormalizationMethod.LOGIT);
    pmmlModel.setMiningFunction(MiningFunction.REGRESSION);

    RegressionTable regressionTable = new RegressionTable();
    regressionTable.setIntercept(lr.getBias());

    // Source field -> derived field computed from it, for the three supported expression kinds.
    LocalTransformations transformations = pmmlModel.getLocalTransformations();
    HashMap<FieldName, FieldName> sourceToDerived = new HashMap<FieldName, FieldName>();
    for (DerivedField derived : transformations.getDerivedFields()) {
        Object expression = derived.getExpression();
        FieldName sourceField;
        if (expression instanceof NormContinuous) {
            // normalization applied to a numerical variable
            sourceField = ((NormContinuous) expression).getField();
        } else if (expression instanceof MapValues) {
            // value mapping applied to a categorical variable; only the first column pair is read
            sourceField = ((MapValues) expression).getFieldColumnPairs().get(0).getField();
        } else if (expression instanceof Discretize) {
            sourceField = ((Discretize) expression).getField();
        } else {
            // any other expression type does not take part in the chain
            continue;
        }
        sourceToDerived.put(sourceField, derived.getName());
    }

    int weightIndex = 0;
    for (MiningField miningField : pmmlModel.getMiningSchema().getMiningFields()) {
        if (miningField.getUsageType() == UsageType.ACTIVE) {
            // Walk the transformation chain to its final derived field.
            FieldName predictorName = miningField.getName();
            while (sourceToDerived.containsKey(predictorName)) {
                predictorName = sourceToDerived.get(predictorName);
            }

            NumericPredictor predictor = new NumericPredictor();
            predictor.setName(predictorName);
            predictor.setCoefficient(lr.getWeights()[weightIndex++]);
            regressionTable.addNumericPredictors(predictor);
        }
    }

    pmmlModel.addRegressionTables(regressionTable);
    return pmmlModel;
}
Also used : NormContinuous(org.dmg.pmml.NormContinuous) MiningField(org.dmg.pmml.MiningField) HashMap(java.util.HashMap) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) RegressionTable(org.dmg.pmml.regression.RegressionTable) LocalTransformations(org.dmg.pmml.LocalTransformations) MapValues(org.dmg.pmml.MapValues) Discretize(org.dmg.pmml.Discretize) DerivedField(org.dmg.pmml.DerivedField) FieldName(org.dmg.pmml.FieldName)

Example 37 with DerivedField

use of org.dmg.pmml.DerivedField in project shifu by ShifuML.

the class WoeZscoreLocalTransformCreator method createCategoricalDerivedField.

/**
 * Create the derived fields for a categorical variable: a WOE field followed by a
 * scaled version of it.
 *
 * <p>The first field is the first element returned by the parent implementation when it is
 * called with {@code NormType.WOE}. The second field applies a {@link NormContinuous} to the
 * first: with {@code s} being the two values returned by
 * {@code Normalizer.calculateWoeMeanAndStdDev} (by its name, mean and standard deviation),
 * {@code s[0] - s[1] * cutoff} maps to {@code -cutoff} and {@code s[0] + s[1] * cutoff}
 * maps to {@code cutoff}. Outliers are treated as extreme values and a missing input maps
 * to 0.0.
 *
 * @param config - ColumnConfig for categorical variable
 * @param cutoff - cutoff for normalization
 * @param normType - normalization type; used here only to generate the second field's name
 * @return list containing the WOE field and then the scaled field
 */
@Override
protected List<DerivedField> createCategoricalDerivedField(ColumnConfig config, double cutoff, ModelNormalizeConf.NormType normType) {
    List<DerivedField> result = new ArrayList<DerivedField>();

    // Step 1: reuse the parent creator for the plain WOE mapping.
    DerivedField woeField = super.createCategoricalDerivedField(config, cutoff, ModelNormalizeConf.NormType.WOE).get(0);
    result.add(woeField);

    // Step 2: linear scaling of the WOE value, capped at +/- cutoff.
    double[] woeStats = Normalizer.calculateWoeMeanAndStdDev(config, isWeightedNorm);
    double center = woeStats[0];
    double halfWidth = woeStats[1] * cutoff;
    LinearNorm lowerAnchor = new LinearNorm().setOrig(center - halfWidth).setNorm(-cutoff);
    LinearNorm upperAnchor = new LinearNorm().setOrig(center + halfWidth).setNorm(cutoff);

    NormContinuous scaling = new NormContinuous();
    scaling.setField(FieldName.create(woeField.getName().getValue()));
    scaling.addLinearNorms(lowerAnchor, upperAnchor);
    scaling.setMapMissingTo(0.0);
    // values outside the two anchors are clamped to the anchor outputs
    scaling.setOutliers(OutlierTreatmentMethod.AS_EXTREME_VALUES);

    // The scaled field's name is generated from the simple column name and normType.
    FieldName scaledName = FieldName.create(genPmmlColumnName(NormalUtils.getSimpleColumnName(config.getColumnName()), normType));
    DerivedField scaledField = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE)
            .setName(scaledName)
            .setExpression(scaling);
    result.add(scaledField);
    return result;
}
Also used : NormContinuous(org.dmg.pmml.NormContinuous) LinearNorm(org.dmg.pmml.LinearNorm) ArrayList(java.util.ArrayList) DerivedField(org.dmg.pmml.DerivedField)

Example 38 with DerivedField

use of org.dmg.pmml.DerivedField in project shifu by ShifuML.

the class WoeLocalTransformCreator method createNumericalDerivedField.

/**
 * Create @DerivedField for numerical variable by discretizing it into bins and replacing
 * each bin with a per-bin value.
 *
 * <p>One bin is produced per entry of {@code config.getBinBoundary()}. The first bin is open
 * on both sides with no left margin (or spans negative to positive infinity when it is the
 * only bin); the last bin is closed-open with no right margin; every other bin is
 * closed-open between its own boundary and the next one. The value of bin {@code i} is
 * entry {@code i} of {@code config.getBinCountWoe()} when {@code normType} is WOE,
 * otherwise of {@code config.getBinWeightedWoe()}.
 *
 * @param config
 *            - ColumnConfig for numerical variable
 * @param cutoff
 *            - cutoff of normalization
 * @param normType
 *            - the normalization method that is used to generate DerivedField
 * @return DerivedField for variable
 */
@Override
protected List<DerivedField> createNumericalDerivedField(ColumnConfig config, double cutoff, ModelNormalizeConf.NormType normType) {
    List<Double> woeValues;
    if (normType.equals(ModelNormalizeConf.NormType.WOE)) {
        woeValues = config.getBinCountWoe();
    } else {
        woeValues = config.getBinWeightedWoe();
    }

    List<Double> boundaries = config.getBinBoundary();
    int binCount = boundaries.size();
    List<DiscretizeBin> bins = new ArrayList<DiscretizeBin>();
    for (int binIndex = 0; binIndex < binCount; binIndex++) {
        boolean isFirst = binIndex == 0;
        boolean isLast = binIndex == binCount - 1;

        Interval range = new Interval();
        if (isFirst && isLast) {
            // single bin: covers the whole real line
            range.setClosure(Interval.Closure.OPEN_OPEN)
                    .setLeftMargin(Double.NEGATIVE_INFINITY)
                    .setRightMargin(Double.POSITIVE_INFINITY);
        } else if (isFirst) {
            // first of several bins: only the right margin is set
            range.setClosure(Interval.Closure.OPEN_OPEN)
                    .setRightMargin(boundaries.get(binIndex + 1));
        } else if (isLast) {
            // last bin: only the left margin is set
            range.setClosure(Interval.Closure.CLOSED_OPEN)
                    .setLeftMargin(boundaries.get(binIndex));
        } else {
            range.setClosure(Interval.Closure.CLOSED_OPEN)
                    .setLeftMargin(boundaries.get(binIndex))
                    .setRightMargin(boundaries.get(binIndex + 1));
        }

        DiscretizeBin bin = new DiscretizeBin();
        bin.setInterval(range).setBinValue(Double.toString(woeValues.get(binIndex)));
        bins.add(bin);
    }

    FieldName inputField = FieldName.create(NormalUtils.getSimpleColumnName(config, columnConfigList, segmentExpansions, datasetHeaders));
    // Both the missing-value replacement and the default value come from normalizing a null input.
    String missingReplacement = Normalizer.normalize(config, null, cutoff, normType).get(0).toString();
    String defaultReplacement = Normalizer.normalize(config, null, cutoff, normType).get(0).toString();
    DiscretizeBin[] binArray = bins.toArray(new DiscretizeBin[bins.size()]);

    Discretize discretize = new Discretize();
    discretize.setDataType(DataType.DOUBLE)
            .setField(inputField)
            .setMapMissingTo(missingReplacement)
            .setDefaultValue(defaultReplacement)
            .addDiscretizeBins(binArray);

    // The derived field's name is generated from the simple column name and normType.
    FieldName derivedName = FieldName.create(genPmmlColumnName(NormalUtils.getSimpleColumnName(config.getColumnName()), normType));
    List<DerivedField> result = new ArrayList<DerivedField>();
    result.add(new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE).setName(derivedName).setExpression(discretize));
    return result;
}
Also used : ArrayList(java.util.ArrayList) Discretize(org.dmg.pmml.Discretize) DiscretizeBin(org.dmg.pmml.DiscretizeBin) DerivedField(org.dmg.pmml.DerivedField) Interval(org.dmg.pmml.Interval)

Example 39 with DerivedField

use of org.dmg.pmml.DerivedField in project shifu by ShifuML.

the class ZscoreLocalTransformCreator method createCategoricalDerivedField.

/**
 * Create DerivedField for categorical variable as a {@link MapValues} lookup from each
 * category value to its normalized value.
 *
 * <p>Every category obtained by flattening each entry of {@code config.getBinCategory()}
 * becomes one row of an inline table (origin column = category, out column = the
 * normalized value as a string). The default value is the result of normalizing a fixed
 * placeholder category string (presumably one that matches no bin), and the missing-value
 * replacement is the result of normalizing {@code null}.
 *
 * @param config
 *            - ColumnConfig for categorical variable
 * @param cutoff
 *            - cutoff for normalization
 * @param normType
 *            - the normalization method that is used to generate DerivedField
 * @return DerivedField for variable
 * @throws RuntimeException if the XML document used to build table cells cannot be created
 */
protected List<DerivedField> createCategoricalDerivedField(ColumnConfig config, double cutoff, ModelNormalizeConf.NormType normType) {
    // A DOM document is only needed as a factory for the inline table's cell elements.
    final Document xmlDocument;
    try {
        xmlDocument = DocumentBuilderFactory.newInstance().newDocumentBuilder().newDocument();
    } catch (ParserConfigurationException e) {
        LOG.error("Fail to create document node.", e);
        throw new RuntimeException("Fail to create document node.", e);
    }

    String fallbackValue = Normalizer.normalize(config, "doesn't exist at all...by paypal", cutoff, normType).get(0).toString();
    String missingReplacement = Normalizer.normalize(config, null, cutoff, normType).get(0).toString();

    InlineTable lookupTable = new InlineTable();
    for (int groupIndex = 0; groupIndex < config.getBinCategory().size(); groupIndex++) {
        for (String category : CommonUtils.flattenCatValGrp(config.getBinCategory().get(groupIndex))) {
            String normalized = Normalizer.normalize(config, category, cutoff, normType).get(0).toString();

            Element originCell = xmlDocument.createElementNS(NAME_SPACE_URI, ELEMENT_ORIGIN);
            originCell.setTextContent(category);
            Element outCell = xmlDocument.createElementNS(NAME_SPACE_URI, ELEMENT_OUT);
            outCell.setTextContent(normalized);

            lookupTable.addRows(new Row().addContent(originCell).addContent(outCell));
        }
    }

    FieldName inputField = new FieldName(NormalUtils.getSimpleColumnName(config, columnConfigList, segmentExpansions, datasetHeaders));
    FieldColumnPair inputColumn = new FieldColumnPair(inputField, ELEMENT_ORIGIN);
    MapValues mapValues = new MapValues("out")
            .setDataType(DataType.DOUBLE)
            .setDefaultValue(fallbackValue)
            .addFieldColumnPairs(inputColumn)
            .setInlineTable(lookupTable)
            .setMapMissingTo(missingReplacement);

    // The derived field's name is generated from the simple column name and normType.
    FieldName derivedName = FieldName.create(genPmmlColumnName(NormalUtils.getSimpleColumnName(config.getColumnName()), normType));
    List<DerivedField> result = new ArrayList<DerivedField>();
    result.add(new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE).setName(derivedName).setExpression(mapValues));
    return result;
}
Also used : InlineTable(org.dmg.pmml.InlineTable) Element(org.w3c.dom.Element) ArrayList(java.util.ArrayList) FieldColumnPair(org.dmg.pmml.FieldColumnPair) Document(org.w3c.dom.Document) MapValues(org.dmg.pmml.MapValues) ParserConfigurationException(javax.xml.parsers.ParserConfigurationException) Row(org.dmg.pmml.Row) FieldName(org.dmg.pmml.FieldName) DerivedField(org.dmg.pmml.DerivedField)

Example 40 with DerivedField

use of org.dmg.pmml.DerivedField in project shifu by ShifuML.

the class ZscoreLocalTransformCreator method build.

/**
 * Builds the local transformations for all selected columns.
 *
 * <p>A column contributes derived fields when {@code isFinalSelect()} is true. When
 * {@code basicML} is a {@link BasicFloatNetwork} with a non-empty feature set, the column
 * number must additionally be contained in that feature set. Categorical columns are
 * handled by {@code createCategoricalDerivedField}, all others by
 * {@code createNumericalDerivedField}, using the cutoff and normalization type from
 * {@code modelConfig}.
 *
 * @param basicML the model the transformations are built for
 * @return a new {@link LocalTransformations} holding the derived fields of every included column
 */
@Override
public LocalTransformations build(BasicML basicML) {
    LocalTransformations transformations = new LocalTransformations();

    boolean isFloatNetwork = basicML instanceof BasicFloatNetwork;
    Set<Integer> selectedFeatures = null;
    if (isFloatNetwork) {
        selectedFeatures = ((BasicFloatNetwork) basicML).getFeatureSet();
    }

    for (ColumnConfig column : columnConfigList) {
        if (!column.isFinalSelect()) {
            continue;
        }
        // Only a float network with a non-empty feature set restricts the columns further.
        if (isFloatNetwork && !CollectionUtils.isEmpty(selectedFeatures) && !selectedFeatures.contains(column.getColumnNum())) {
            continue;
        }

        double cutoff = modelConfig.getNormalizeStdDevCutOff();
        List<DerivedField> columnFields;
        if (column.isCategorical()) {
            columnFields = createCategoricalDerivedField(column, cutoff, modelConfig.getNormalizeType());
        } else {
            columnFields = createNumericalDerivedField(column, cutoff, modelConfig.getNormalizeType());
        }
        transformations.addDerivedFields(columnFields.toArray(new DerivedField[columnFields.size()]));
    }
    return transformations;
}
Also used : LocalTransformations(org.dmg.pmml.LocalTransformations) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) DerivedField(org.dmg.pmml.DerivedField)

Aggregations

DerivedField (org.dmg.pmml.DerivedField) 48, ArrayList (java.util.ArrayList) 17, ContinuousFeature (org.jpmml.converter.ContinuousFeature) 16, Feature (org.jpmml.converter.Feature) 16, Apply (org.dmg.pmml.Apply) 10, Expression (org.dmg.pmml.Expression) 10, FieldName (org.dmg.pmml.FieldName) 9, Test (org.junit.Test) 8, KiePMMLDerivedField (org.kie.pmml.commons.transformations.KiePMMLDerivedField) 8, Constant (org.dmg.pmml.Constant) 7, DataField (org.dmg.pmml.DataField) 7, NormContinuous (org.dmg.pmml.NormContinuous) 6, BlockStmt (com.github.javaparser.ast.stmt.BlockStmt) 5, List (java.util.List) 5, CategoricalFeature (org.jpmml.converter.CategoricalFeature) 5, Discretize (org.dmg.pmml.Discretize) 4, FieldRef (org.dmg.pmml.FieldRef) 4, MapValues (org.dmg.pmml.MapValues) 4, Statement (com.github.javaparser.ast.stmt.Statement) 3, HashMap (java.util.HashMap) 3