Use of org.dmg.pmml.regression.PredictorTerm in project drools by kiegroup.
In class PMMLModelTestUtils, method getPredictorTerm:
/**
 * Creates a {@link PredictorTerm} test fixture.
 *
 * @param name          the term name, wrapped with {@code FieldName.create}
 * @param coefficient   the coefficient assigned to the term
 * @param fieldRefNames the referenced field names; one {@link FieldRef} (built by
 *                      {@code getFieldRef}) is added per name, in list order
 * @return the populated predictor term
 */
public static PredictorTerm getPredictorTerm(String name, double coefficient, List<String> fieldRefNames) {
    final PredictorTerm predictorTerm = new PredictorTerm();
    predictorTerm.setName(FieldName.create(name));
    predictorTerm.setCoefficient(coefficient);
    final FieldRef[] fieldRefs = new FieldRef[fieldRefNames.size()];
    int position = 0;
    for (String fieldRefName : fieldRefNames) {
        fieldRefs[position++] = PMMLModelTestUtils.getFieldRef(fieldRefName);
    }
    predictorTerm.addFieldRefs(fieldRefs);
    return predictorTerm;
}
Use of org.dmg.pmml.regression.PredictorTerm in project drools by kiegroup.
In class KiePMMLRegressionModelFactoryTest, method setup:
/**
 * Builds the shared PMML fixtures used by the tests of this class.
 * <p>
 * Three regression tables are created. Each holds three categorical predictors
 * ("CatPred-0..2"), three numeric predictors ("NumPred-0..2") and three predictor
 * terms ("PredTerm-0..2"). Each term references the categorical and numeric
 * predictor field with the same index. Numeric arguments come from an unseeded
 * {@link Random}, so they differ between runs.
 * <p>
 * Every distinct field name gets a CATEGORICAL/STRING data field and an ACTIVE
 * mining field. The first mining field is then switched to TARGET, and the
 * data dictionary, mining schema, regression model, template compilation unit
 * and {@link PMML} container are assembled from those pieces.
 */
@BeforeClass
public static void setup() {
    final Random random = new Random();
    final Set<String> fieldNames = new HashSet<>();
    regressionTables = new ArrayList<>();
    for (int tableIndex = 0; tableIndex < 3; tableIndex++) {
        final List<CategoricalPredictor> categoricalPredictors = new ArrayList<>();
        final List<NumericPredictor> numericPredictors = new ArrayList<>();
        final List<PredictorTerm> predictorTerms = new ArrayList<>();
        for (int predictorIndex = 0; predictorIndex < 3; predictorIndex++) {
            final String categoricalName = "CatPred-" + predictorIndex;
            final String numericName = "NumPred-" + predictorIndex;
            categoricalPredictors.add(getCategoricalPredictor(categoricalName, random.nextDouble(), random.nextDouble()));
            numericPredictors.add(getNumericPredictor(numericName, random.nextInt(), random.nextDouble()));
            predictorTerms.add(getPredictorTerm("PredTerm-" + predictorIndex, random.nextDouble(),
                                                Arrays.asList(categoricalName, numericName)));
            fieldNames.add(categoricalName);
            fieldNames.add(numericName);
        }
        regressionTables.add(getRegressionTable(categoricalPredictors,
                                                numericPredictors,
                                                predictorTerms,
                                                tableIntercept + random.nextDouble(),
                                                tableTargetCategory + "-" + tableIndex));
    }
    dataFields = new ArrayList<>();
    miningFields = new ArrayList<>();
    for (String fieldName : fieldNames) {
        dataFields.add(getDataField(fieldName, OpType.CATEGORICAL, DataType.STRING));
        miningFields.add(getMiningField(fieldName, MiningField.UsageType.ACTIVE));
    }
    // NOTE(review): the target is whichever field the HashSet yields first, and that
    // field is also referenced as a predictor by the regression tables. Confirm this
    // overlap is intended.
    targetMiningField = miningFields.get(0);
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    dataDictionary = getDataDictionary(dataFields);
    transformationDictionary = new TransformationDictionary();
    miningSchema = getMiningSchema(miningFields);
    regressionModel = getRegressionModel(modelName, MiningFunction.REGRESSION, miningSchema, regressionTables);
    COMPILATION_UNIT = getFromFileName(KIE_PMML_REGRESSION_MODEL_TEMPLATE_JAVA);
    MODEL_TEMPLATE = COMPILATION_UNIT.getClassByName(KIE_PMML_REGRESSION_MODEL_TEMPLATE).get();
    pmml = new PMML();
    pmml.setDataDictionary(dataDictionary);
    pmml.setTransformationDictionary(transformationDictionary);
    pmml.addModels(regressionModel);
}
Use of org.dmg.pmml.regression.PredictorTerm in project drools by kiegroup.
In class KiePMMLRegressionTableFactoryTest, method getPredictorTermsMap:
/**
 * Verifies that {@code KiePMMLRegressionTableFactory.getPredictorTermsMap} returns
 * one entry per input predictor term, keyed by the term's name.
 */
@Test
public void getPredictorTermsMap() {
    final List<PredictorTerm> predictorTerms = IntStream.range(0, 3)
            .mapToObj(index -> PMMLModelTestUtils.getPredictorTerm("predictorName-" + index,
                                                                  1.23 * index,
                                                                  Collections.singletonList("fieldRef-" + index)))
            .collect(Collectors.toList());
    final Map<String, SerializableFunction<Map<String, Object>, Double>> termFunctions =
            KiePMMLRegressionTableFactory.getPredictorTermsMap(predictorTerms);
    assertEquals(predictorTerms.size(), termFunctions.size());
    for (PredictorTerm predictorTerm : predictorTerms) {
        assertTrue(termFunctions.containsKey(predictorTerm.getName().getValue()));
    }
}
Use of org.dmg.pmml.regression.PredictorTerm in project drools by kiegroup.
In class KiePMMLRegressionTableFactoryTest, method getPredictorTermFunction:
/**
 * Verifies the lambda that {@code KiePMMLRegressionTableFactory.getPredictorTermFunction}
 * generates for a single-field predictor term. The expected expression is the
 * {@code TEST_07_SOURCE} template formatted with the field name and the coefficient.
 *
 * @throws IOException if the expected-source template cannot be read
 */
@Test
public void getPredictorTermFunction() throws IOException {
    final String fieldRef = "fieldRef";
    final double coefficient = 23.12;
    final PredictorTerm predictorTerm =
            PMMLModelTestUtils.getPredictorTerm("predictorName", coefficient, Collections.singletonList(fieldRef));
    final LambdaExpr generated = KiePMMLRegressionTableFactory.getPredictorTermFunction(predictorTerm);
    final String template = getFileContent(TEST_07_SOURCE);
    final Expression expected = JavaParserUtils.parseExpression(String.format(template, fieldRef, coefficient));
    assertTrue(JavaParserUtils.equalsNode(expected, generated));
}
Aggregations