Usage of org.dmg.pmml.regression.CategoricalPredictor in the drools project by kiegroup.
Class KiePMMLRegressionTableFactory, method getCategoricalPredictorsExpressions:
/**
 * Create the <b>CategoricalPredictor</b>s lambda <code>Expression</code>s map.
 * <p>
 * The given predictors are grouped by the value of their field name. For every group, a map variable
 * is declared inside <code>body</code> (through <code>populateWithGroupedCategoricalPredictorMap</code>),
 * and an <code>Expression</code> referring to that variable is built through
 * <code>getCategoricalPredictorExpression</code>.
 *
 * @param categoricalPredictors the predictors to process; each one is grouped by
 * <code>getField().getValue()</code>
 * @param body the <code>BlockStmt</code> that receives one map declaration per distinct field name
 * @param variableName base name for the generated map variables; it is suffixed with <code>Map</code>,
 * sanitized, and then combined with a running index through <code>VARIABLE_NAME_TEMPLATE</code>
 * @return a <code>Map</code> from field name to the <code>Expression</code> generated for that field
 */
static Map<String, Expression> getCategoricalPredictorsExpressions(final List<CategoricalPredictor> categoricalPredictors, final BlockStmt body, final String variableName) {
    final Map<String, List<CategoricalPredictor>> groupedCollectors = categoricalPredictors.stream().collect(groupingBy(categoricalPredictor -> categoricalPredictor.getField().getValue()));
    final String categoricalPredictorMapNameBase = getSanitizedVariableName(String.format("%sMap", variableName));
    // A plain loop instead of a stream pipeline: every iteration mutates `body` and consumes the
    // counter, and stream behavioral parameters are meant to be stateless and side-effect free.
    final Map<String, Expression> toReturn = new java.util.HashMap<>();
    int counter = 0;
    for (final Map.Entry<String, List<CategoricalPredictor>> entry : groupedCollectors.entrySet()) {
        final String categoricalPredictorMapName = String.format(VARIABLE_NAME_TEMPLATE, categoricalPredictorMapNameBase, counter++);
        populateWithGroupedCategoricalPredictorMap(entry.getValue(), body, categoricalPredictorMapName);
        toReturn.put(entry.getKey(), getCategoricalPredictorExpression(categoricalPredictorMapName));
    }
    return toReturn;
}
Usage of org.dmg.pmml.regression.CategoricalPredictor in the drools project by kiegroup.
Class KiePMMLRegressionModelFactoryTest, method setup:
/**
 * Builds the shared fixtures for the tests in this class.
 * <p>
 * Three regression tables are created. Each table holds three categorical predictors, three numeric
 * predictors and three predictor terms, all filled with random numbers. Every field name that is used
 * gets a categorical/string <code>DataField</code> and an active <code>MiningField</code>. The first
 * mining field is then switched to the TARGET usage type. Finally, the data dictionary, mining schema,
 * regression model, template compilation unit and <code>PMML</code> container are assembled.
 */
@BeforeClass
public static void setup() {
    final Random rnd = new Random();
    final Set<String> fieldNames = new HashSet<>();
    regressionTables = new ArrayList<>();
    for (int tableIndex = 0; tableIndex < 3; tableIndex++) {
        final List<CategoricalPredictor> categoricalPredictors = new ArrayList<>();
        final List<NumericPredictor> numericPredictors = new ArrayList<>();
        final List<PredictorTerm> predictorTerms = new ArrayList<>();
        for (int predictorIndex = 0; predictorIndex < 3; predictorIndex++) {
            final String catFieldName = "CatPred-" + predictorIndex;
            final String numFieldName = "NumPred-" + predictorIndex;
            categoricalPredictors.add(getCategoricalPredictor(catFieldName, rnd.nextDouble(), rnd.nextDouble()));
            numericPredictors.add(getNumericPredictor(numFieldName, rnd.nextInt(), rnd.nextDouble()));
            predictorTerms.add(getPredictorTerm("PredTerm-" + predictorIndex, rnd.nextDouble(), Arrays.asList(catFieldName, numFieldName)));
            fieldNames.add(catFieldName);
            fieldNames.add(numFieldName);
        }
        regressionTables.add(getRegressionTable(categoricalPredictors, numericPredictors, predictorTerms,
                tableIntercept + rnd.nextDouble(), tableTargetCategory + "-" + tableIndex));
    }
    dataFields = new ArrayList<>();
    miningFields = new ArrayList<>();
    for (final String fieldName : fieldNames) {
        dataFields.add(getDataField(fieldName, OpType.CATEGORICAL, DataType.STRING));
        miningFields.add(getMiningField(fieldName, MiningField.UsageType.ACTIVE));
    }
    // Re-purpose the first mining field (HashSet iteration order) as the model target.
    targetMiningField = miningFields.get(0);
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    dataDictionary = getDataDictionary(dataFields);
    transformationDictionary = new TransformationDictionary();
    miningSchema = getMiningSchema(miningFields);
    regressionModel = getRegressionModel(modelName, MiningFunction.REGRESSION, miningSchema, regressionTables);
    COMPILATION_UNIT = getFromFileName(KIE_PMML_REGRESSION_MODEL_TEMPLATE_JAVA);
    MODEL_TEMPLATE = COMPILATION_UNIT.getClassByName(KIE_PMML_REGRESSION_MODEL_TEMPLATE).get();
    pmml = new PMML();
    pmml.setDataDictionary(dataDictionary);
    pmml.setTransformationDictionary(transformationDictionary);
    pmml.addModels(regressionModel);
}
Usage of org.dmg.pmml.regression.CategoricalPredictor in the drools project by kiegroup.
Class KiePMMLRegressionTableFactoryTest, method populateWithGroupedCategoricalPredictorMap:
/**
 * Checks that <code>populateWithGroupedCategoricalPredictorMap</code> fills a <code>BlockStmt</code>
 * with the same statements as the <code>TEST_04_SOURCE</code> template. The template is formatted with
 * the map name followed by the (value, coefficient) pair of each of the three predictors.
 */
@Test
public void populateWithGroupedCategoricalPredictorMap() throws IOException {
    final int predictorCount = 3;
    final List<CategoricalPredictor> categoricalPredictors = new ArrayList<>();
    for (int index = 0; index < predictorCount; index++) {
        categoricalPredictors.add(PMMLModelTestUtils.getCategoricalPredictor("predictorName-" + index, index, 1.23 * index));
    }
    final BlockStmt toPopulate = new BlockStmt();
    final String categoricalPredictorMapName = "categoricalPredictorMapName";
    KiePMMLRegressionTableFactory.populateWithGroupedCategoricalPredictorMap(categoricalPredictors, toPopulate, categoricalPredictorMapName);
    final String text = getFileContent(TEST_04_SOURCE);
    // Format arguments: the map name, then value and coefficient of each predictor in list order.
    final Object[] templateArgs = new Object[1 + 2 * predictorCount];
    templateArgs[0] = categoricalPredictorMapName;
    for (int index = 0; index < predictorCount; index++) {
        final CategoricalPredictor predictor = categoricalPredictors.get(index);
        templateArgs[1 + 2 * index] = predictor.getValue();
        templateArgs[2 + 2 * index] = predictor.getCoefficient();
    }
    final BlockStmt expected = JavaParserUtils.parseBlock(String.format(text, templateArgs));
    assertTrue(JavaParserUtils.equalsNode(expected, toPopulate));
}
Aggregations