Usage of org.dmg.pmml.DerivedField in project shifu by ShifuML.
Class PMMLAdapterCommonUtil, method getRegressionTable.
/**
 * Generate a Regression Table from the weight list and intercept and attach it
 * to the given partial PMML regression model.
 *
 * <p>The model's mining function is set to REGRESSION and its normalization
 * method to LOGIT. The first TARGET field of the mining schema becomes both the
 * model's target field name and the table's target category. One numeric
 * predictor is added for every {@link NormContinuous} derived field whose source
 * field is an ACTIVE mining field; {@code weights} are consumed in the order
 * those derived fields appear in the local transformations.
 *
 * @param weights
 *            weight list for the Regression Table, one per matching
 *            NormContinuous derived field, in derived-field order
 * @param intercept
 *            the intercept (bias term) of the Regression Table
 * @param pmmlModel
 *            partial PMML model; its mining schema and local transformations
 *            are read, and the new table is added to it
 * @return the same {@code pmmlModel} instance, with the new table added
 */
public static RegressionModel getRegressionTable(final double[] weights, final double intercept,
        RegressionModel pmmlModel) {
    RegressionTable table = new RegressionTable();
    // The intercept was previously accepted but never applied, which produced a
    // table without a bias term.
    table.setIntercept(intercept);

    MiningSchema schema = pmmlModel.getMiningSchema();
    // TODO may not need target field in LRModel
    pmmlModel.setMiningFunction(MiningFunction.REGRESSION);
    pmmlModel.setNormalizationMethod(NormalizationMethod.LOGIT);

    List<String> outputFields = getSchemaFieldViaUsageType(schema, UsageType.TARGET);
    // TODO only one outputField, what if we have more than one outputField
    String targetField = outputFields.get(0);
    pmmlModel.setTargetFieldName(new FieldName(targetField));
    table.setTargetCategory(targetField);

    List<String> activeFields = getSchemaFieldViaUsageType(schema, UsageType.ACTIVE);
    int index = 0;
    for (DerivedField dField : pmmlModel.getLocalTransformations().getDerivedFields()) {
        Expression expression = dField.getExpression();
        if (expression instanceof NormContinuous) {
            NormContinuous norm = (NormContinuous) expression;
            // Only z-scaled fields built from ACTIVE inputs become weighted predictors.
            if (activeFields.contains(norm.getField().getValue())) {
                table.addNumericPredictors(new NumericPredictor(dField.getName(), weights[index++]));
            }
        }
    }
    pmmlModel.addRegressionTables(table);
    return pmmlModel;
}
Usage of org.dmg.pmml.DerivedField in project shifu by ShifuML.
Class PMMLLRModelBuilder, method adaptMLModelToPMML.
/**
 * Populate {@code pmmlModel} with a LOGIT regression table built from a Shifu LR model.
 *
 * <p>The table's intercept is the LR bias. Each ACTIVE mining field contributes one
 * numeric predictor, weighted by the LR weights in mining-field order. The predictor
 * name is found by following the derived-field chain that starts at the mining field
 * (NormContinuous, MapValues via its first field/column pair, or Discretize) to its
 * last link, so the predictor refers to the final transformed field.
 *
 * @param lr the trained logistic-regression model supplying bias and weights
 * @param pmmlModel partial PMML model whose local transformations and mining schema are read
 * @return the same {@code pmmlModel} instance with the regression table added
 */
public RegressionModel adaptMLModelToPMML(ml.shifu.shifu.core.LR lr, RegressionModel pmmlModel) {
    pmmlModel.setNormalizationMethod(NormalizationMethod.LOGIT);
    pmmlModel.setMiningFunction(MiningFunction.REGRESSION);
    RegressionTable regressionTable = new RegressionTable();
    regressionTable.setIntercept(lr.getBias());

    // source field -> field derived from it; lets chained transforms be followed to the end
    HashMap<FieldName, FieldName> sourceToDerived = new HashMap<FieldName, FieldName>();
    for (DerivedField derived : pmmlModel.getLocalTransformations().getDerivedFields()) {
        Object expr = derived.getExpression();
        if (expr instanceof NormContinuous) {
            // z-scale normalization on numerical variables
            sourceToDerived.put(((NormContinuous) expr).getField(), derived.getName());
        } else if (expr instanceof MapValues) {
            // bin map on categorical variables
            sourceToDerived.put(((MapValues) expr).getFieldColumnPairs().get(0).getField(), derived.getName());
        } else if (expr instanceof Discretize) {
            sourceToDerived.put(((Discretize) expr).getField(), derived.getName());
        }
    }

    int weightIndex = 0;
    for (MiningField miningField : pmmlModel.getMiningSchema().getMiningFields()) {
        if (miningField.getUsageType() == UsageType.ACTIVE) {
            FieldName resolved = miningField.getName();
            while (sourceToDerived.containsKey(resolved)) {
                resolved = sourceToDerived.get(resolved);
            }
            NumericPredictor predictor = new NumericPredictor();
            predictor.setName(resolved);
            predictor.setCoefficient(lr.getWeights()[weightIndex++]);
            regressionTable.addNumericPredictors(predictor);
        }
    }

    pmmlModel.addRegressionTables(regressionTable);
    return pmmlModel;
}
Usage of org.dmg.pmml.DerivedField in project shifu by ShifuML.
Class WoeZscoreLocalTransformCreator, method createCategoricalDerivedField.
/**
 * Create the @DerivedField list for a categorical variable under WOE + z-score normalization.
 *
 * <p>Two fields are returned, in this order:
 * <ol>
 * <li>the WOE-mapped field produced by the parent class (always built with {@code NormType.WOE});</li>
 * <li>a {@link NormContinuous} field that z-scales that WOE value with the WOE mean and standard
 * deviation, mapping {@code mean -/+ stdDev * cutoff} to {@code -/+ cutoff}, treating outliers as
 * extreme values (capping) and mapping missing values to {@code 0.0}.</li>
 * </ol>
 *
 * @param config - ColumnConfig for categorical variable
 * @param cutoff - cutoff for normalization
 * @param normType - normalization type used to generate the name of the z-scaled field
 * @return the WOE field followed by the z-scaled field
 */
@Override
protected List<DerivedField> createCategoricalDerivedField(ColumnConfig config, double cutoff,
        ModelNormalizeConf.NormType normType) {
    DerivedField woeField = super.createCategoricalDerivedField(config, cutoff, ModelNormalizeConf.NormType.WOE)
            .get(0);
    List<DerivedField> result = new ArrayList<DerivedField>();
    result.add(woeField);

    double[] meanAndStdDev = Normalizer.calculateWoeMeanAndStdDev(config, isWeightedNorm);
    // capping bounds for the linear normalization
    double lowerOrig = meanAndStdDev[0] - meanAndStdDev[1] * cutoff;
    double upperOrig = meanAndStdDev[0] + meanAndStdDev[1] * cutoff;
    LinearNorm lower = new LinearNorm().setOrig(lowerOrig).setNorm(-cutoff);
    LinearNorm upper = new LinearNorm().setOrig(upperOrig).setNorm(cutoff);

    NormContinuous zscore = new NormContinuous();
    zscore.setField(FieldName.create(woeField.getName().getValue()));
    zscore.addLinearNorms(lower, upper);
    zscore.setMapMissingTo(0.0);
    zscore.setOutliers(OutlierTreatmentMethod.AS_EXTREME_VALUES);

    // the z-scaled field's name is derived from the simple column name and normType
    String zscoreName = genPmmlColumnName(NormalUtils.getSimpleColumnName(config.getColumnName()), normType);
    DerivedField zscoreField = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE)
            .setName(FieldName.create(zscoreName))
            .setExpression(zscore);
    result.add(zscoreField);
    return result;
}
Usage of org.dmg.pmml.DerivedField in project shifu by ShifuML.
Class WoeLocalTransformCreator, method createNumericalDerivedField.
/**
 * Create @DerivedField for numerical variable
 *
 * <p>The raw numeric value is mapped through a {@link Discretize} expression to the WOE value
 * of the bin it falls into. With bin boundaries {@code b0, b1, ..., bk}, the generated intervals
 * are {@code (b1)} open on the left with no left margin, then {@code [b1, b2)}, ...,
 * {@code [bk, ...)} with no right margin. A single boundary yields one bin covering
 * {@code (-inf, +inf)}. Missing inputs and inputs that match no bin both map to the normalized
 * value of a {@code null} input.
 *
 * @param config
 *            - ColumnConfig for numerical variable
 * @param cutoff
 *            - cutoff of normalization
 * @param normType
 *            - the normalization method that is used to generate DerivedField; {@code WOE}
 *            selects the count-based bin WOE, any other value the weighted bin WOE
 * @return single-element list holding the DerivedField for the variable
 */
@Override
protected List<DerivedField> createNumericalDerivedField(ColumnConfig config, double cutoff,
        ModelNormalizeConf.NormType normType) {
    List<Double> binWoeList = normType.equals(ModelNormalizeConf.NormType.WOE) ? config.getBinCountWoe()
            : config.getBinWeightedWoe();
    List<Double> binBoundaryList = config.getBinBoundary();
    int binCount = binBoundaryList.size();

    List<DiscretizeBin> discretizeBinList = new ArrayList<DiscretizeBin>(binCount);
    for (int i = 0; i < binCount; i++) {
        Interval interval = new Interval();
        if (binCount == 1) {
            // A single bin spans the whole real line.
            interval.setClosure(Interval.Closure.OPEN_OPEN).setLeftMargin(Double.NEGATIVE_INFINITY)
                    .setRightMargin(Double.POSITIVE_INFINITY);
        } else if (i == 0) {
            // First bin: only a right margin; boundary b0 is not used as a left margin.
            interval.setClosure(Interval.Closure.OPEN_OPEN).setRightMargin(binBoundaryList.get(i + 1));
        } else if (i == binCount - 1) {
            // Last bin: only a left margin.
            interval.setClosure(Interval.Closure.CLOSED_OPEN).setLeftMargin(binBoundaryList.get(i));
        } else {
            interval.setClosure(Interval.Closure.CLOSED_OPEN).setLeftMargin(binBoundaryList.get(i))
                    .setRightMargin(binBoundaryList.get(i + 1));
        }
        DiscretizeBin discretizeBin = new DiscretizeBin();
        discretizeBin.setInterval(interval).setBinValue(Double.toString(binWoeList.get(i)));
        discretizeBinList.add(discretizeBin);
    }

    // Missing inputs and inputs outside every bin share one fallback value; compute it once
    // instead of normalizing the same null input twice.
    String fallbackValue = Normalizer.normalize(config, null, cutoff, normType).get(0).toString();

    Discretize discretize = new Discretize();
    discretize.setDataType(DataType.DOUBLE)
            .setField(FieldName.create(
                    NormalUtils.getSimpleColumnName(config, columnConfigList, segmentExpansions, datasetHeaders)))
            .setMapMissingTo(fallbackValue)
            .setDefaultValue(fallbackValue)
            .addDiscretizeBins(discretizeBinList.toArray(new DiscretizeBin[discretizeBinList.size()]));

    // The derived field name is built from the simple column name and normType.
    List<DerivedField> derivedFields = new ArrayList<DerivedField>();
    derivedFields.add(new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE)
            .setName(FieldName.create(genPmmlColumnName(NormalUtils.getSimpleColumnName(config.getColumnName()),
                    normType)))
            .setExpression(discretize));
    return derivedFields;
}
Usage of org.dmg.pmml.DerivedField in project shifu by ShifuML.
Class ZscoreLocalTransformCreator, method createCategoricalDerivedField.
/**
 * Create the DerivedField for a categorical variable as an inline {@link MapValues} lookup.
 *
 * <p>Every category value (after flattening grouped categories) becomes one inline-table row
 * that maps the raw value (column {@code ELEMENT_ORIGIN}) to its normalized value (column
 * {@code ELEMENT_OUT}). Values absent from the table map to the normalized value of a sentinel
 * string that is expected to match no category. Missing values map to the normalized value of
 * a {@code null} input.
 *
 * @param config
 *            - ColumnConfig for categorical variable
 * @param cutoff
 *            - cutoff for normalization
 * @param normType
 *            - the normalization method that is used to generate DerivedField
 * @return single-element list holding the DerivedField for the variable
 * @throws RuntimeException
 *             if no XML document could be created for the inline-table cells
 */
protected List<DerivedField> createCategoricalDerivedField(ColumnConfig config, double cutoff,
        ModelNormalizeConf.NormType normType) {
    final Document document;
    try {
        document = DocumentBuilderFactory.newInstance().newDocumentBuilder().newDocument();
    } catch (ParserConfigurationException e) {
        LOG.error("Fail to create document node.", e);
        throw new RuntimeException("Fail to create document node.", e);
    }

    String defaultValue = Normalizer.normalize(config, "doesn't exist at all...by paypal", cutoff, normType).get(0)
            .toString();
    String missingValue = Normalizer.normalize(config, null, cutoff, normType).get(0).toString();

    InlineTable inlineTable = new InlineTable();
    for (String binCategory : config.getBinCategory()) {
        for (String rawValue : CommonUtils.flattenCatValGrp(binCategory)) {
            String normalizedValue = Normalizer.normalize(config, rawValue, cutoff, normType).get(0).toString();
            Element outCell = document.createElementNS(NAME_SPACE_URI, ELEMENT_OUT);
            outCell.setTextContent(normalizedValue);
            Element originCell = document.createElementNS(NAME_SPACE_URI, ELEMENT_ORIGIN);
            originCell.setTextContent(rawValue);
            Row row = new Row();
            row.addContent(originCell);
            row.addContent(outCell);
            inlineTable.addRows(row);
        }
    }

    FieldName sourceField = new FieldName(
            NormalUtils.getSimpleColumnName(config, columnConfigList, segmentExpansions, datasetHeaders));
    MapValues mapValues = new MapValues("out")
            .setDataType(DataType.DOUBLE)
            .setDefaultValue(defaultValue)
            .addFieldColumnPairs(new FieldColumnPair(sourceField, ELEMENT_ORIGIN))
            .setInlineTable(inlineTable)
            .setMapMissingTo(missingValue);

    String derivedName = genPmmlColumnName(NormalUtils.getSimpleColumnName(config.getColumnName()), normType);
    List<DerivedField> derivedFields = new ArrayList<DerivedField>();
    derivedFields.add(new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE)
            .setName(FieldName.create(derivedName))
            .setExpression(mapValues));
    return derivedFields;
}
Aggregations