Use of org.dmg.pmml.regression.PredictorTerm in the drools project by kiegroup.
Class KiePMMLRegressionTableFactory, method getPredictorTermBody:
/**
 * Builds the body of the <code>evaluatePredictor</code> template method for the given <b>PredictorTerm</b>.
 * <p>
 * The template is loaded from {@code KIE_PMML_EVALUATE_METHOD_TEMPLATE_JAVA} and cloned. In the cloned
 * <code>evaluatePredictor</code> body:
 * <ul>
 *   <li>the {@code fieldRefs} variable is initialised with {@code Arrays.asList(...)} over the names of the
 *   term's field references;</li>
 *   <li>the {@code COEFFICIENT} variable is initialised with the term's coefficient.</li>
 * </ul>
 *
 * @param predictorTerm the <b>PredictorTerm</b> whose field references and coefficient are written into the template
 * @return the populated <code>BlockStmt</code> of the cloned <code>evaluatePredictor</code> method
 * @throws KiePMMLInternalException if the template class, method, body or an expected variable cannot be found,
 * or if any other error occurs while populating the template
 */
static BlockStmt getPredictorTermBody(final PredictorTerm predictorTerm) {
    try {
        // NOTE(review): templateEvaluate / cloneEvaluate are declared outside this view (presumably static fields).
        templateEvaluate = getFromFileName(KIE_PMML_EVALUATE_METHOD_TEMPLATE_JAVA);
        // Work on a clone so the parsed template itself is left untouched.
        cloneEvaluate = templateEvaluate.clone();
        ClassOrInterfaceDeclaration evaluateTemplateClass = cloneEvaluate.getClassByName(KIE_PMML_EVALUATE_METHOD_TEMPLATE).orElseThrow(() -> new RuntimeException(MAIN_CLASS_NOT_FOUND));
        MethodDeclaration methodTemplate = evaluateTemplateClass.getMethodsByName("evaluatePredictor").get(0);
        final BlockStmt body = methodTemplate.getBody().orElseThrow(() -> new KiePMMLInternalException(String.format(MISSING_BODY_TEMPLATE, methodTemplate.getName())));
        // fieldRefs = Arrays.asList("field1", "field2", ...)
        VariableDeclarator variableDeclarator = getVariableDeclarator(body, "fieldRefs").orElseThrow(() -> new KiePMMLInternalException(String.format(MISSING_VARIABLE_IN_BODY, "fieldRefs", body)));
        final List<Expression> nodeList = predictorTerm.getFieldRefs().stream().map(fieldRef -> new StringLiteralExpr(fieldRef.getField().getValue())).collect(Collectors.toList());
        NodeList<Expression> expressions = NodeList.nodeList(nodeList);
        MethodCallExpr methodCallExpr = new MethodCallExpr(new NameExpr("Arrays"), "asList", expressions);
        variableDeclarator.setInitializer(methodCallExpr);
        // COEFFICIENT = <coefficient as double literal>
        variableDeclarator = getVariableDeclarator(body, COEFFICIENT).orElseThrow(() -> new KiePMMLInternalException(String.format(MISSING_VARIABLE_IN_BODY, COEFFICIENT, body)));
        variableDeclarator.setInitializer(String.valueOf(predictorTerm.getCoefficient().doubleValue()));
        // The body was already resolved (and null-checked) above: return it instead of looking it up a second time.
        return body;
    } catch (Exception e) {
        throw new KiePMMLInternalException(String.format("Failed to add PredictorTerm %s", predictorTerm), e);
    }
}
Use of org.dmg.pmml.regression.PredictorTerm in the drools project by kiegroup.
Class KiePMMLRegressionModelFactoryTest, method evaluateRegressionTable:
/**
 * Verifies that a generated {@code KiePMMLRegressionTable} mirrors its source PMML {@code RegressionTable}.
 * <p>
 * The intercepts must be equal. Every numeric predictor, categorical predictor and predictor term of the source
 * table must appear, keyed by its name, in the corresponding function map of the generated table.
 *
 * @param regressionTable the generated table under test
 * @param originalRegressionTable the PMML table it was generated from
 */
private void evaluateRegressionTable(KiePMMLRegressionTable regressionTable, RegressionTable originalRegressionTable) {
    assertEquals(originalRegressionTable.getIntercept(), regressionTable.getIntercept());

    final Map<String, SerializableFunction<Double, Double>> numericFunctions = regressionTable.getNumericFunctionMap();
    originalRegressionTable.getNumericPredictors()
            .forEach(numeric -> assertTrue(numericFunctions.containsKey(numeric.getName().getValue())));

    final Map<String, SerializableFunction<String, Double>> categoricalFunctions = regressionTable.getCategoricalFunctionMap();
    originalRegressionTable.getCategoricalPredictors()
            .forEach(categorical -> assertTrue(categoricalFunctions.containsKey(categorical.getName().getValue())));

    final Map<String, SerializableFunction<Map<String, Object>, Double>> termFunctions = regressionTable.getPredictorTermsFunctionMap();
    originalRegressionTable.getPredictorTerms()
            .forEach(term -> assertTrue(termFunctions.containsKey(term.getName().getValue())));
}
Use of org.dmg.pmml.regression.PredictorTerm in the drools project by kiegroup.
Class KiePMMLRegressionTableFactoryTest, method getPredictorTermFunctions:
/**
 * Checks that {@code getPredictorTermFunctions} returns exactly one entry per predictor term,
 * keyed by the term's name.
 */
@Test
public void getPredictorTermFunctions() {
    // Three terms with distinct names, coefficients 0, 1.23 and 2.46, and one field reference each.
    final List<PredictorTerm> predictorTerms = IntStream.range(0, 3)
            .mapToObj(i -> PMMLModelTestUtils.getPredictorTerm("predictorName-" + i,
                                                               1.23 * i,
                                                               Collections.singletonList("fieldRef-" + i)))
            .collect(Collectors.toList());
    final Map<String, Expression> functionsByName = KiePMMLRegressionTableFactory.getPredictorTermFunctions(predictorTerms);
    assertEquals(predictorTerms.size(), functionsByName.size());
    for (PredictorTerm term : predictorTerms) {
        assertTrue(functionsByName.containsKey(term.getName().getValue()));
    }
}
Use of org.dmg.pmml.regression.PredictorTerm in the drools project by kiegroup.
Class KiePMMLRegressionTableFactoryTest, method getPredictorTermSerializableFunction:
/**
 * Checks that {@code getPredictorTermSerializableFunction} produces a function for a
 * single-field predictor term.
 */
@Test
public void getPredictorTermSerializableFunction() {
    final PredictorTerm term = PMMLModelTestUtils.getPredictorTerm("predictorName",
                                                                   23.12,
                                                                   Collections.singletonList("fieldRef"));
    final SerializableFunction<Map<String, Object>, Double> function =
            KiePMMLRegressionTableFactory.getPredictorTermSerializableFunction(term);
    // Only existence is verified here; the function is not evaluated.
    assertNotNull(function);
}
Use of org.dmg.pmml.regression.PredictorTerm in the drools project by kiegroup.
Class PMMLModelTestUtils, method getRegressionTable:
/**
 * Creates a PMML {@code RegressionTable} populated with the given predictors, intercept and target category.
 *
 * @param categoricalPredictors categorical predictors to add
 * @param numericPredictors numeric predictors to add
 * @param predictorTerms predictor terms to add
 * @param intercept the table intercept
 * @param targetCategory the target category of the table
 * @return a new {@code RegressionTable}
 */
public static RegressionTable getRegressionTable(List<CategoricalPredictor> categoricalPredictors, List<NumericPredictor> numericPredictors, List<PredictorTerm> predictorTerms, double intercept, Object targetCategory) {
    final RegressionTable table = new RegressionTable();
    table.setIntercept(intercept);
    table.setTargetCategory(targetCategory);
    final CategoricalPredictor[] categoricals = categoricalPredictors.toArray(new CategoricalPredictor[0]);
    table.addCategoricalPredictors(categoricals);
    final NumericPredictor[] numerics = numericPredictors.toArray(new NumericPredictor[0]);
    table.addNumericPredictors(numerics);
    final PredictorTerm[] terms = predictorTerms.toArray(new PredictorTerm[0]);
    table.addPredictorTerms(terms);
    return table;
}
Aggregations