Use of org.dmg.pmml.SimplePredicate in project drools by kiegroup.
The class KiePMMLSimplePredicateFactoryTest, method getSimplePredicateVariableDeclaration.
/**
 * Verifies the variable declaration generated for a single EQUAL simple predicate on a
 * DOUBLE field, then checks that the generated block compiles with the expected imports.
 */
@Test
public void getSimplePredicateVariableDeclaration() throws IOException {
    final String declaredVariable = "variableName";

    final SimplePredicate predicate = new SimplePredicate();
    predicate.setField(FieldName.create("CUSTOM_FIELD"));
    predicate.setValue("235.435");
    predicate.setOperator(SimplePredicate.Operator.EQUAL);

    // Fully qualified operator constant, substituted into the expected-source template below.
    final String qualifiedOperator = String.format("%s.%s", OPERATOR.class.getName(), OPERATOR.byName(predicate.getOperator().value()));

    // Declare the predicate's field as DOUBLE in the data dictionary.
    final DataField customField = new DataField();
    customField.setName(predicate.getField());
    customField.setDataType(DataType.DOUBLE);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(customField);

    final BlockStmt generated = KiePMMLSimplePredicateFactory.getSimplePredicateVariableDeclaration(declaredVariable, predicate, getFieldsFromDataDictionary(dictionary));

    // Template arguments, in order: variable name, field name, operator constant, predicate value.
    final String template = getFileContent(TEST_01_SOURCE);
    final String expectedSource = String.format(template, declaredVariable, predicate.getField().getValue(), qualifiedOperator, predicate.getValue());
    final Statement expected = JavaParserUtils.parseBlock(expectedSource);
    assertTrue(JavaParserUtils.equalsNode(expected, generated));

    final List<Class<?>> requiredImports = Arrays.asList(KiePMMLSimplePredicate.class, Collections.class);
    commonValidateCompilationWithImports(generated, requiredImports);
}
Use of org.dmg.pmml.SimplePredicate in project drools by kiegroup.
The class KiePMMLASTFactoryUtilsTest, method getConstraintEntryFromSimplePredicates.
/**
 * Verifies that two LESS_THAN predicates on the same field, joined with OR, are rendered
 * as a single constraint string on that field.
 */
@Test
public void getConstraintEntryFromSimplePredicates() {
    final String fieldName = "FIELD_NAME";
    // Populate the type map once, before building the predicates. Doing it inside the
    // stream's mapToObj step would be a side effect in an intermediate operation, and it
    // would re-put the same entry for every element.
    final Map<String, KiePMMLOriginalTypeGeneratedType> fieldTypeMap = new HashMap<>();
    fieldTypeMap.put(fieldName, new KiePMMLOriginalTypeGeneratedType(DataType.STRING.value(), getSanitizedClassName(fieldName.toUpperCase())));
    // Two LESS_THAN predicates on the same field, with values "VALUE-0" and "VALUE-1".
    final List<SimplePredicate> simplePredicates = IntStream.range(0, 2)
        .mapToObj(index -> PMMLModelTestUtils.getSimplePredicate(fieldName, "VALUE-" + index, SimplePredicate.Operator.LESS_THAN))
        .collect(Collectors.toList());
    final KiePMMLFieldOperatorValue retrieved = KiePMMLASTFactoryUtils.getConstraintEntryFromSimplePredicates(fieldName, BOOLEAN_OPERATOR.OR, simplePredicates, fieldTypeMap);
    assertEquals(fieldName, retrieved.getName());
    assertNotNull(retrieved.getConstraintsAsString());
    final String expected = "value < \"VALUE-0\" || value < \"VALUE-1\"";
    assertEquals(expected, retrieved.getConstraintsAsString());
}
Use of org.dmg.pmml.SimplePredicate in project drools by kiegroup.
The class KiePMMLCompoundPredicateFactoryTest, method getCompoundPredicateVariableDeclaration.
/**
 * Verifies the variable declaration generated for an AND of two simple predicates and one
 * IS_IN simple-set predicate over four STRING values, then checks that the generated block
 * compiles with the expected imports.
 */
@Test
public void getCompoundPredicateVariableDeclaration() throws IOException {
    final String variableName = "variableName";

    // Operands: two simple predicates and a simple-set predicate over 4 string values.
    final SimplePredicate firstPredicate = getSimplePredicate(PARAM_1, value1, operator1);
    final SimplePredicate secondPredicate = getSimplePredicate(PARAM_2, value2, operator2);
    final Array.Type setArrayType = Array.Type.STRING;
    final List<String> setValues = getStringObjects(setArrayType, 4);
    final SimpleSetPredicate setPredicate = getSimpleSetPredicate(setValues, setArrayType, SimpleSetPredicate.BooleanOperator.IS_IN);

    // AND the operands together, appended in order on the newly created compound predicate.
    final CompoundPredicate compoundPredicate = new CompoundPredicate();
    compoundPredicate.setBooleanOperator(CompoundPredicate.BooleanOperator.AND);
    compoundPredicate.getPredicates().add(firstPredicate);
    compoundPredicate.getPredicates().add(secondPredicate);
    compoundPredicate.getPredicates().add(setPredicate);

    // Declare every referenced field as DOUBLE in the data dictionary.
    final DataField firstField = new DataField();
    firstField.setName(firstPredicate.getField());
    firstField.setDataType(DataType.DOUBLE);
    final DataField secondField = new DataField();
    secondField.setName(secondPredicate.getField());
    secondField.setDataType(DataType.DOUBLE);
    final DataField setField = new DataField();
    setField.setName(setPredicate.getField());
    setField.setDataType(DataType.DOUBLE);
    final DataDictionary dataDictionary = new DataDictionary();
    dataDictionary.addDataFields(firstField, secondField, setField);

    // Template arguments: the fully qualified boolean-operator constant and the quoted,
    // comma-separated set values.
    final BOOLEAN_OPERATOR booleanOperator = BOOLEAN_OPERATOR.byName(compoundPredicate.getBooleanOperator().value());
    final String booleanOperatorString = String.join(".", BOOLEAN_OPERATOR.class.getName(), booleanOperator.name());
    final String valuesString = setValues.stream()
        .map(setValue -> String.format("\"%s\"", setValue))
        .collect(Collectors.joining(","));

    final List<Field<?>> fields = getFieldsFromDataDictionary(dataDictionary);
    final BlockStmt retrieved = KiePMMLCompoundPredicateFactory.getCompoundPredicateVariableDeclaration(variableName, compoundPredicate, fields);

    final String template = getFileContent(TEST_01_SOURCE);
    final String expectedSource = String.format(template, variableName, valuesString, booleanOperatorString);
    final Statement expected = JavaParserUtils.parseBlock(expectedSource);
    assertTrue(JavaParserUtils.equalsNode(expected, retrieved));

    // The generated block must compile against these imports.
    final List<Class<?>> requiredImports = Arrays.asList(KiePMMLCompoundPredicate.class, KiePMMLSimplePredicate.class, KiePMMLSimpleSetPredicate.class, Arrays.class, Collections.class);
    commonValidateCompilationWithImports(retrieved, requiredImports);
}
Use of org.dmg.pmml.SimplePredicate in project drools by kiegroup.
The class PMMLModelTestUtils, method getSimplePredicate.
/**
 * Builds a {@link SimplePredicate} that compares the field named {@code predicateName}
 * against {@code value} using {@code operator}.
 *
 * @param predicateName name of the field the predicate refers to
 * @param value         value to compare the field against
 * @param operator      comparison operator
 * @return a new, fully populated predicate
 */
public static SimplePredicate getSimplePredicate(final String predicateName, final Object value, final SimplePredicate.Operator operator) {
    final SimplePredicate predicate = new SimplePredicate();
    predicate.setField(FieldName.create(predicateName));
    predicate.setOperator(operator);
    predicate.setValue(value);
    return predicate;
}
Use of org.dmg.pmml.SimplePredicate in project jpmml-r by jpmml.
The class ScorecardConverter, method encodeModel.
/**
 * Encodes the fitted R GLM held by this converter as a PMML {@link Scorecard}.
 *
 * <p>Scaling parameters ({@code odds}, {@code base_points}, {@code pdo}) are read from the
 * {@code sc.conf} decoration of the model object. Each binary feature becomes an EQUAL-predicate
 * {@link Attribute} whose partial score is {@code -coefficient * pdo / ln(2)}. Attributes are
 * grouped into one {@link Characteristic} per feature name, and every characteristic is
 * closed with a catch-all ({@code True}) attribute scoring 0.
 *
 * @param schema the schema whose features correspond to the non-intercept coefficients
 * @return the encoded regression scorecard, with reason codes disabled
 * @throws IllegalArgumentException if any feature is not a {@link BinaryFeature}
 */
@Override
public Scorecard encodeModel(Schema schema) {
    RGenericVector glm = getObject();

    RDoubleVector coefficients = glm.getDoubleElement("coefficients");
    // NOTE(review): "family" is never read below. The lookup is kept in case it doubles as a
    // presence check on the fitted object — confirm, and remove it if not.
    RGenericVector family = glm.getGenericElement("family");
    RGenericVector scConf = DecorationUtil.getGenericElement(glm, "sc.conf");

    // The intercept may be null; every use below is guarded by a null check.
    Double intercept = coefficients.getElement(LMConverter.INTERCEPT, false);

    List<? extends Feature> features = schema.getFeatures();
    SchemaUtil.checkSize(coefficients.size() - (intercept != null ? 1 : 0), features);

    RNumberVector<?> odds = scConf.getNumericElement("odds");
    RNumberVector<?> basePoints = scConf.getNumericElement("base_points");
    RNumberVector<?> pdo = scConf.getNumericElement("pdo");

    // Points-per-log-odds scaling factor: pdo / ln(2).
    double factor = (pdo.asScalar()).doubleValue() / Math.log(2);

    // One characteristic per feature name, kept in first-seen order.
    Map<String, Characteristic> fieldCharacteristics = new LinkedHashMap<>();
    for (Feature feature : features) {
        String name = feature.getName();

        if (!(feature instanceof BinaryFeature)) {
            throw new IllegalArgumentException("Expected a binary feature, got " + feature);
        }

        BinaryFeature binaryFeature = (BinaryFeature) feature;
        Double coefficient = getFeatureCoefficient(feature, coefficients);

        Characteristic characteristic = fieldCharacteristics.computeIfAbsent(name,
            key -> new Characteristic().setName("score(" + FeatureUtil.getName(feature) + ")"));

        SimplePredicate simplePredicate = new SimplePredicate(binaryFeature.getName(), SimplePredicate.Operator.EQUAL, binaryFeature.getValue());
        // The partial score is the negated, scaled coefficient.
        Attribute attribute = new Attribute(simplePredicate).setPartialScore(formatScore(-1d * coefficient * factor));
        characteristic.addAttributes(attribute);
    }

    // Close every characteristic with a catch-all attribute that scores 0.
    Characteristics characteristics = new Characteristics();
    for (Characteristic characteristic : fieldCharacteristics.values()) {
        characteristic.addAttributes(new Attribute(True.INSTANCE).setPartialScore(0d));
        characteristics.addCharacteristics(characteristic);
    }

    // initialScore = base_points - ln(odds) * factor - intercept * factor (intercept term omitted when absent).
    double interceptPoints = (intercept != null ? intercept * factor : 0d);
    double initialScore = (basePoints.asScalar()).doubleValue() - Math.log((odds.asScalar()).doubleValue()) * factor - interceptPoints;

    return new Scorecard(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), characteristics)
        .setInitialScore(formatScore(initialScore))
        .setUseReasonCodes(false);
}
Aggregations