Use of `org.dmg.pmml.regression.CategoricalPredictor` in project drools by kiegroup.
In class `KiePMMLRegressionModelFactoryTest`, method `evaluateRegressionTable`:
/**
 * Asserts that a generated {@code KiePMMLRegressionTable} mirrors its source PMML {@code RegressionTable}.
 * The intercepts must be equal, and every numeric predictor, categorical predictor and predictor term
 * of the original must have an entry, keyed by its name, in the corresponding function map.
 *
 * @param regressionTable         the generated table under test
 * @param originalRegressionTable the PMML table it was generated from
 */
private void evaluateRegressionTable(KiePMMLRegressionTable regressionTable, RegressionTable originalRegressionTable) {
    assertEquals(originalRegressionTable.getIntercept(), regressionTable.getIntercept());

    final Map<String, SerializableFunction<Double, Double>> numericFunctions = regressionTable.getNumericFunctionMap();
    originalRegressionTable.getNumericPredictors()
            .forEach(numeric -> assertTrue(numericFunctions.containsKey(numeric.getName().getValue())));

    final Map<String, SerializableFunction<String, Double>> categoricalFunctions = regressionTable.getCategoricalFunctionMap();
    originalRegressionTable.getCategoricalPredictors()
            .forEach(categorical -> assertTrue(categoricalFunctions.containsKey(categorical.getName().getValue())));

    final Map<String, SerializableFunction<Map<String, Object>, Double>> termFunctions = regressionTable.getPredictorTermsFunctionMap();
    originalRegressionTable.getPredictorTerms()
            .forEach(term -> assertTrue(termFunctions.containsKey(term.getName().getValue())));
}
Use of `org.dmg.pmml.regression.CategoricalPredictor` in project drools by kiegroup.
In class `KiePMMLRegressionTableFactoryTest`, method `getCategoricalPredictorsExpressions`:
/**
 * Builds categorical predictors for several fields, each with several category values. It then checks
 * that {@code KiePMMLRegressionTableFactory.getCategoricalPredictorsExpressions} returns one expression
 * per distinct field. It also checks that the generated statements in {@code body} match each field's
 * predictors.
 */
@Test
public void getCategoricalPredictorsExpressions() {
    final int fieldCount = 3;
    final int categoriesPerField = 3;
    // One predictor per (field, category) pair, in field-major order. flatMap concatenates the
    // per-field predictors directly. This replaces reduce(...).get(), which copied the list at
    // every step and called Optional.get() unchecked.
    final List<CategoricalPredictor> categoricalPredictors = IntStream.range(0, fieldCount)
            .boxed()
            .flatMap(index -> IntStream.range(0, categoriesPerField).mapToObj(i -> {
                String predictorName = "predictorName-" + index;
                double coefficient = 1.23 * i;
                return PMMLModelTestUtils.getCategoricalPredictor(predictorName, i, coefficient);
            }))
            .collect(Collectors.toList());
    final BlockStmt body = new BlockStmt();
    Map<String, Expression> retrieved = KiePMMLRegressionTableFactory.getCategoricalPredictorsExpressions(categoricalPredictors, body, "variableName");
    // The test expects one entry per distinct field name, not one per predictor.
    assertEquals(fieldCount, retrieved.size());
    final Map<String, List<CategoricalPredictor>> groupedCollectors = categoricalPredictors.stream()
            .collect(groupingBy(categoricalPredictor -> categoricalPredictor.getField().getValue()));
    groupedCollectors.values()
            .forEach(predictorsOfField -> commonEvaluateCategoryPredictors(body, predictorsOfField, "variableName"));
}
Use of `org.dmg.pmml.regression.CategoricalPredictor` in project drools by kiegroup.
In class `KiePMMLRegressionTableFactoryTest`, method `commonEvaluateCategoryPredictors`:
/**
 * Asserts that {@code toVerify} contains, for each predictor at position {@code i} in
 * {@code categoricalPredictors}, an expression statement that is a method call rendering exactly as
 * {@code <sanitized(variableName + "Map")>_<i>.put("<value>", <coefficient>);}.
 *
 * @param toVerify              the generated block whose statements are inspected
 * @param categoricalPredictors predictors whose {@code put} statements are expected
 * @param variableName          base name used to derive the expected map variable name
 */
private void commonEvaluateCategoryPredictors(final BlockStmt toVerify, final List<CategoricalPredictor> categoricalPredictors, final String variableName) {
    // The sanitized map-name prefix does not depend on the predictor index, so compute it once.
    final String mapVariablePrefix = getSanitizedVariableName(String.format("%sMap", variableName));
    for (int i = 0; i < categoricalPredictors.size(); i++) {
        final CategoricalPredictor categoricalPredictor = categoricalPredictors.get(i);
        final String expectedVariableName = mapVariablePrefix + "_" + i;
        // Built once per predictor. Previously it was rebuilt inside the lambda for every inspected statement.
        final String expected = String.format("%s.put(\"%s\", %s);", expectedVariableName, categoricalPredictor.getValue(), categoricalPredictor.getCoefficient());
        assertTrue(toVerify.getStatements().stream().anyMatch(statement ->
                statement instanceof ExpressionStmt
                        && ((ExpressionStmt) statement).getExpression() instanceof MethodCallExpr
                        && statement.toString().equals(expected)));
    }
}
Use of `org.dmg.pmml.regression.CategoricalPredictor` in project drools by kiegroup.
In class `PMMLModelTestUtils`, method `getCategoricalPredictor`:
/**
 * Creates a PMML {@code CategoricalPredictor} for testing.
 *
 * @param name        the field the predictor refers to
 * @param value       the category value the coefficient applies to
 * @param coefficient the coefficient associated with that category
 * @return a new predictor with field, value and coefficient set
 */
public static CategoricalPredictor getCategoricalPredictor(String name, double value, double coefficient) {
    final CategoricalPredictor predictor = new CategoricalPredictor();
    final FieldName fieldName = FieldName.create(name);
    predictor.setField(fieldName);
    predictor.setValue(value);
    predictor.setCoefficient(coefficient);
    return predictor;
}
Use of `org.dmg.pmml.regression.CategoricalPredictor` in project drools by kiegroup.
In class `PMMLModelTestUtils`, method `getRegressionTable`:
/**
 * Assembles a PMML {@code RegressionTable} for testing from the given predictors.
 *
 * @param categoricalPredictors categorical predictors to add
 * @param numericPredictors     numeric predictors to add
 * @param predictorTerms        predictor terms to add
 * @param intercept             the table intercept
 * @param targetCategory        the target category this table applies to
 * @return a new, fully populated regression table
 */
public static RegressionTable getRegressionTable(List<CategoricalPredictor> categoricalPredictors, List<NumericPredictor> numericPredictors, List<PredictorTerm> predictorTerms, double intercept, Object targetCategory) {
    final RegressionTable table = new RegressionTable();
    table.setIntercept(intercept);
    table.setTargetCategory(targetCategory);

    final CategoricalPredictor[] categoricalArray = categoricalPredictors.toArray(new CategoricalPredictor[0]);
    table.addCategoricalPredictors(categoricalArray);

    final NumericPredictor[] numericArray = numericPredictors.toArray(new NumericPredictor[0]);
    table.addNumericPredictors(numericArray);

    final PredictorTerm[] termArray = predictorTerms.toArray(new PredictorTerm[0]);
    table.addPredictorTerms(termArray);

    return table;
}
Aggregations