use of org.dmg.pmml.regression.NumericPredictor in project drools by kiegroup.
the class PMMLModelTestUtils method getRegressionTable.
/**
 * Builds a {@link RegressionTable} carrying the given intercept, target category and predictors.
 *
 * @param categoricalPredictors categorical predictors added to the table
 * @param numericPredictors numeric predictors added to the table
 * @param predictorTerms predictor terms added to the table
 * @param intercept intercept set on the table
 * @param targetCategory target category set on the table
 * @return the newly created and populated table
 */
public static RegressionTable getRegressionTable(List<CategoricalPredictor> categoricalPredictors, List<NumericPredictor> numericPredictors, List<PredictorTerm> predictorTerms, double intercept, Object targetCategory) {
    final RegressionTable table = new RegressionTable();
    table.setIntercept(intercept);
    table.setTargetCategory(targetCategory);

    // The add* methods take arrays/varargs, so each list is converted first.
    final CategoricalPredictor[] categoricalArray = categoricalPredictors.toArray(new CategoricalPredictor[0]);
    table.addCategoricalPredictors(categoricalArray);

    final NumericPredictor[] numericArray = numericPredictors.toArray(new NumericPredictor[0]);
    table.addNumericPredictors(numericArray);

    final PredictorTerm[] termArray = predictorTerms.toArray(new PredictorTerm[0]);
    table.addPredictorTerms(termArray);

    return table;
}
use of org.dmg.pmml.regression.NumericPredictor in project drools by kiegroup.
the class KiePMMLRegressionModelFactoryTest method setup.
/**
 * Initializes the static fixtures shared by the tests of this class: three randomly
 * populated regression tables, the data/mining fields derived from the predictor names,
 * the regression model, the parsed template and the enclosing {@code PMML} instance.
 */
@BeforeClass
public static void setup() {
    final Random rnd = new Random();
    final Set<String> collectedNames = new HashSet<>();
    regressionTables = IntStream.range(0, 3).mapToObj(tableIndex -> {
        final List<CategoricalPredictor> categoricals = new ArrayList<>();
        final List<NumericPredictor> numerics = new ArrayList<>();
        final List<PredictorTerm> terms = new ArrayList<>();
        // Each table gets three categorical predictors, three numeric predictors and three terms.
        for (int k = 0; k < 3; k++) {
            final String categoricalName = "CatPred-" + k;
            final String numericName = "NumPred-" + k;
            categoricals.add(getCategoricalPredictor(categoricalName, rnd.nextDouble(), rnd.nextDouble()));
            numerics.add(getNumericPredictor(numericName, rnd.nextInt(), rnd.nextDouble()));
            terms.add(getPredictorTerm("PredTerm-" + k, rnd.nextDouble(), Arrays.asList(categoricalName, numericName)));
            collectedNames.add(categoricalName);
            collectedNames.add(numericName);
        }
        final double intercept = tableIntercept + rnd.nextDouble();
        final String category = tableTargetCategory + "-" + tableIndex;
        return getRegressionTable(categoricals, numerics, terms, intercept, category);
    }).collect(Collectors.toList());

    // One data field and one (initially ACTIVE) mining field per distinct predictor name.
    dataFields = new ArrayList<>();
    miningFields = new ArrayList<>();
    for (String name : collectedNames) {
        dataFields.add(getDataField(name, OpType.CATEGORICAL, DataType.STRING));
        miningFields.add(getMiningField(name, MiningField.UsageType.ACTIVE));
    }

    // The first mining field is promoted to be the target.
    targetMiningField = miningFields.get(0);
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);

    dataDictionary = getDataDictionary(dataFields);
    transformationDictionary = new TransformationDictionary();
    miningSchema = getMiningSchema(miningFields);
    regressionModel = getRegressionModel(modelName, MiningFunction.REGRESSION, miningSchema, regressionTables);

    COMPILATION_UNIT = getFromFileName(KIE_PMML_REGRESSION_MODEL_TEMPLATE_JAVA);
    MODEL_TEMPLATE = COMPILATION_UNIT.getClassByName(KIE_PMML_REGRESSION_MODEL_TEMPLATE).get();

    pmml = new PMML();
    pmml.setDataDictionary(dataDictionary);
    pmml.setTransformationDictionary(transformationDictionary);
    pmml.addModels(regressionModel);
}
use of org.dmg.pmml.regression.NumericPredictor in project drools by kiegroup.
the class KiePMMLRegressionTableFactoryTest method getNumericPredictorExpressionWithExponent.
/**
 * Checks that the expression generated for a numeric predictor with exponent 2 matches
 * the expected source read from {@code TEST_01_SOURCE}, once formatted with the
 * coefficient and the exponent.
 */
@Test
public void getNumericPredictorExpressionWithExponent() throws IOException {
    final String name = "predictorName";
    final int power = 2;
    final double weight = 1.23;
    final NumericPredictor predictor = PMMLModelTestUtils.getNumericPredictor(name, power, weight);

    final CastExpr actual = KiePMMLRegressionTableFactory.getNumericPredictorExpression(predictor);

    // The expected-source file is a format template taking (coefficient, exponent).
    final String template = getFileContent(TEST_01_SOURCE);
    final String expectedSource = String.format(template, weight, power);
    final Expression expected = JavaParserUtils.parseExpression(expectedSource);
    final boolean sameNode = JavaParserUtils.equalsNode(expected, actual);
    assertTrue(sameNode);
}
use of org.dmg.pmml.regression.NumericPredictor in project shifu by ShifuML.
the class PMMLAdapterCommonUtil method getRegressionTable.
/**
 * Generate a Regression Table from the weight list and intercept, attach it to
 * the partial PMML model and return that model.
 * <p>
 * The model is mutated in place: its mining function is set to REGRESSION, its
 * normalization method to LOGIT, and its target field name to the first field
 * of the mining schema whose usage type is TARGET. One numeric predictor is
 * added for every derived field whose expression is a {@code NormContinuous}
 * over an ACTIVE field; weights are consumed in that iteration order.
 *
 * @param weights
 *            weight list for the Regression Table
 * @param intercept
 *            the intercept applied to the Regression Table
 * @param pmmlModel
 *            partial PMML model; must already carry a mining schema and
 *            local transformations
 * @return the same {@code pmmlModel} instance, with the new regression table added
 */
public static RegressionModel getRegressionTable(final double[] weights, final double intercept, RegressionModel pmmlModel) {
    RegressionTable table = new RegressionTable();
    // Apply the bias term; without this the exported model ignores the intercept argument.
    table.setIntercept(intercept);
    MiningSchema schema = pmmlModel.getMiningSchema();
    // TODO may not need target field in LRModel
    pmmlModel.setMiningFunction(MiningFunction.REGRESSION);
    pmmlModel.setNormalizationMethod(NormalizationMethod.LOGIT);
    List<String> outputFields = getSchemaFieldViaUsageType(schema, UsageType.TARGET);
    // TODO only one outputField, what if we have more than one outputField
    // NOTE(review): presumably fails on get(0) when the schema has no TARGET field - confirm
    pmmlModel.setTargetFieldName(new FieldName(outputFields.get(0)));
    table.setTargetCategory(outputFields.get(0));
    List<String> activeFields = getSchemaFieldViaUsageType(schema, UsageType.ACTIVE);
    int index = 0;
    for (DerivedField dField : pmmlModel.getLocalTransformations().getDerivedFields()) {
        Expression expression = dField.getExpression();
        if (expression instanceof NormContinuous) {
            NormContinuous norm = (NormContinuous) expression;
            // Only normalized fields that are ACTIVE in the schema consume a weight.
            if (activeFields.contains(norm.getField().getValue())) {
                table.addNumericPredictors(new NumericPredictor(dField.getName(), weights[index++]));
            }
        }
    }
    pmmlModel.addRegressionTables(table);
    return pmmlModel;
}
use of org.dmg.pmml.regression.NumericPredictor in project shifu by ShifuML.
the class PMMLLRModelBuilder method adaptMLModelToPMML.
/**
 * Fills the given PMML regression model from a trained LR model.
 * <p>
 * The model is configured with LOGIT normalization and the REGRESSION mining function.
 * A single regression table is added whose intercept is the LR bias and which holds one
 * numeric predictor per ACTIVE mining field. Each predictor is named after the last
 * derived field reachable from the mining field through the local transformations, and
 * weights are taken from {@code lr.getWeights()} in mining-schema order.
 *
 * @param lr the trained LR model providing bias and weights
 * @param pmmlModel the PMML model to populate; mutated in place
 * @return the same {@code pmmlModel} instance
 */
public RegressionModel adaptMLModelToPMML(ml.shifu.shifu.core.LR lr, RegressionModel pmmlModel) {
    pmmlModel.setNormalizationMethod(NormalizationMethod.LOGIT);
    pmmlModel.setMiningFunction(MiningFunction.REGRESSION);

    RegressionTable regressionTable = new RegressionTable();
    regressionTable.setIntercept(lr.getBias());

    // Map every transformed source field to the derived field computed from it.
    LocalTransformations transformations = pmmlModel.getLocalTransformations();
    List<DerivedField> derivedFields = transformations.getDerivedFields();
    HashMap<FieldName, FieldName> sourceToDerived = new HashMap<FieldName, FieldName>();
    for (DerivedField derived : derivedFields) {
        if (derived.getExpression() instanceof NormContinuous) {
            // z-scale normalization on numerical variables
            NormContinuous normalization = (NormContinuous) derived.getExpression();
            sourceToDerived.put(normalization.getField(), derived.getName());
        } else if (derived.getExpression() instanceof MapValues) {
            // bin map on categorical variables
            MapValues mapping = (MapValues) derived.getExpression();
            sourceToDerived.put(mapping.getFieldColumnPairs().get(0).getField(), derived.getName());
        } else if (derived.getExpression() instanceof Discretize) {
            Discretize discretize = (Discretize) derived.getExpression();
            sourceToDerived.put(discretize.getField(), derived.getName());
        }
    }

    // One numeric predictor per ACTIVE mining field, consuming weights sequentially.
    int weightCursor = 0;
    for (MiningField miningField : pmmlModel.getMiningSchema().getMiningFields()) {
        if (miningField.getUsageType() == UsageType.ACTIVE) {
            NumericPredictor predictor = new NumericPredictor();
            predictor.setName(resolveFinalDerivedName(sourceToDerived, miningField.getName()));
            predictor.setCoefficient(lr.getWeights()[weightCursor++]);
            regressionTable.addNumericPredictors(predictor);
        }
    }

    pmmlModel.addRegressionTables(regressionTable);
    return pmmlModel;
}

/**
 * Follows the source-to-derived mapping from {@code start} until a name with no further mapping is reached.
 */
private static FieldName resolveFinalDerivedName(HashMap<FieldName, FieldName> sourceToDerived, FieldName start) {
    FieldName current = start;
    while (sourceToDerived.containsKey(current)) {
        current = sourceToDerived.get(current);
    }
    return current;
}
Aggregations