use of org.dmg.pmml.DataField in project drools by kiegroup.
the class KiePMMLClassificationTableFactoryTest method setStaticGetter.
/**
 * Builds a two-table regression model with a categorical target, runs
 * {@code KiePMMLClassificationTableFactory.setStaticGetter} on a clone of the static getter
 * template, and expects the result to equal the method parsed from {@code TEST_02_SOURCE}.
 */
@Test
public void setStaticGetter() throws IOException {
    // Regression tables, one per target category.
    final RegressionTable professionalTable = getRegressionTable(3.5, "professional");
    final RegressionTable clericalTable = getRegressionTable(27.4, "clerical");

    // Output fields attached to the model.
    final Output modelOutput = new Output();
    modelOutput.addOutputFields(
            getOutputField("CAT-1", ResultFeature.PROBABILITY, "CatPred-1"),
            getOutputField("NUM-1", ResultFeature.PROBABILITY, "NumPred-0"),
            getOutputField("PREV", ResultFeature.PREDICTED_VALUE, null));

    // Categorical target declared in the data dictionary.
    final DataField targetDataField = new DataField();
    targetDataField.setName(FieldName.create("targetField"));
    targetDataField.setOpType(OpType.CATEGORICAL);
    final DataDictionary dictionary = new DataDictionary();
    dictionary.addDataFields(targetDataField);

    // Mining schema flagging the same field as the target.
    final MiningField targetMiningField = new MiningField();
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    targetMiningField.setName(targetDataField.getName());
    final MiningSchema schema = new MiningSchema();
    schema.addMiningFields(targetMiningField);

    // The regression model under test and its enclosing PMML document.
    final RegressionModel model = new RegressionModel();
    model.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    model.addRegressionTables(professionalTable, clericalTable);
    model.setModelName(getGeneratedClassName("RegressionModel"));
    model.setOutput(modelOutput);
    model.setMiningSchema(schema);
    final PMML pmml = new PMML();
    pmml.setDataDictionary(dictionary);
    pmml.addModels(model);

    final CommonCompilationDTO<RegressionModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, model, new HasClassLoaderMock());
    final RegressionCompilationDTO compilationDTO =
            RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(commonDTO, model.getRegressionTables(), model.getNormalizationMethod());

    // One entry per regression table, keyed by "defpack." plus the upper-cased target category.
    final LinkedHashMap<String, KiePMMLTableSourceCategory> tablesByKey = new LinkedHashMap<>();
    for (RegressionTable table : model.getRegressionTables()) {
        final String category = table.getTargetCategory().toString();
        tablesByKey.put("defpack." + category.toUpperCase(), new KiePMMLTableSourceCategory("", category));
    }

    final String getterVariable = "variableName";
    final MethodDeclaration actual = STATIC_GETTER_METHOD.clone();
    KiePMMLClassificationTableFactory.setStaticGetter(compilationDTO, tablesByKey, actual, getterVariable);

    final MethodDeclaration expected = JavaParserUtils.parseMethod(getFileContent(TEST_02_SOURCE));
    assertTrue(JavaParserUtils.equalsNode(expected, actual));
}
use of org.dmg.pmml.DataField in project drools by kiegroup.
the class RegressionCompilationDTO method isRegression.
/**
 * Tells whether this DTO describes a regression.
 *
 * @return {@code true} when the mining function is {@code MiningFunction.REGRESSION} and the
 *         target data field is either absent or has {@code OpType.CONTINUOUS}
 */
public boolean isRegression() {
    final DataField target = getTargetDataField();
    final boolean targetAbsentOrContinuous;
    if (target == null) {
        targetAbsentOrContinuous = true;
    } else {
        targetAbsentOrContinuous = Objects.equals(OpType.CONTINUOUS, target.getOpType());
    }
    return Objects.equals(MiningFunction.REGRESSION, getMiningFunction()) && targetAbsentOrContinuous;
}
use of org.dmg.pmml.DataField in project drools by kiegroup.
the class KiePMMLCharacteristicFactoryTest method getAttributeVariableDeclarationWithComplexPartialScore.
/**
 * Generates the variable declaration for a characteristic with two attributes whose predicates
 * are compound, then compares it with the block formatted from {@code TEST_01_SOURCE} and
 * validates that it compiles with the listed imports.
 */
@Test
public void getAttributeVariableDeclarationWithComplexPartialScore() throws IOException {
    final String variableName = "variableName";
    final Array.Type stringArrayType = Array.Type.STRING;
    final List<String> firstValues = getStringObjects(stringArrayType, 4);
    final Attribute firstAttribute = getAttribute(firstValues, 1);
    final List<String> secondValues = getStringObjects(stringArrayType, 4);
    final Attribute secondAttribute = getAttribute(secondValues, 2);
    final CompoundPredicate firstCompound = (CompoundPredicate) firstAttribute.getPredicate();
    final CompoundPredicate secondCompound = (CompoundPredicate) secondAttribute.getPredicate();

    // Declare a DOUBLE data field for every simple / simple-set predicate of both attributes,
    // first attribute's predicates first.
    final DataDictionary dataDictionary = new DataDictionary();
    for (CompoundPredicate compound : Arrays.asList(firstCompound, secondCompound)) {
        for (Predicate nested : compound.getPredicates()) {
            final DataField field;
            if (nested instanceof SimplePredicate) {
                field = new DataField();
                field.setName(((SimplePredicate) nested).getField());
            } else if (nested instanceof SimpleSetPredicate) {
                field = new DataField();
                field.setName(((SimpleSetPredicate) nested).getField());
            } else {
                // Other predicate kinds contribute no data field.
                continue;
            }
            field.setDataType(DataType.DOUBLE);
            dataDictionary.addDataFields(field);
        }
    }

    // Comma-separated, double-quoted renderings of the values, used to fill the template.
    final String firstQuoted = firstValues.stream()
            .map(valueString -> "\"" + valueString + "\"")
            .collect(Collectors.joining(","));
    final String secondQuoted = secondValues.stream()
            .map(valueString -> "\"" + valueString + "\"")
            .collect(Collectors.joining(","));

    final Characteristic characteristic = new Characteristic();
    characteristic.addAttributes(firstAttribute, secondAttribute);
    characteristic.setBaselineScore(22);
    characteristic.setReasonCode(REASON_CODE);

    final BlockStmt actual = KiePMMLCharacteristicFactory.getCharacteristicVariableDeclaration(
            variableName, characteristic, getFieldsFromDataDictionary(dataDictionary));

    final String template = getFileContent(TEST_01_SOURCE);
    final Statement expected = JavaParserUtils.parseBlock(String.format(
            template, variableName, firstQuoted, secondQuoted,
            characteristic.getBaselineScore(), characteristic.getReasonCode()));
    assertTrue(JavaParserUtils.equalsNode(expected, actual));

    final List<Class<?>> requiredImports = Arrays.asList(
            KiePMMLAttribute.class,
            KiePMMLCharacteristic.class,
            KiePMMLComplexPartialScore.class,
            KiePMMLCompoundPredicate.class,
            KiePMMLConstant.class,
            KiePMMLSimplePredicate.class,
            KiePMMLSimpleSetPredicate.class,
            Arrays.class,
            Collections.class);
    commonValidateCompilationWithImports(actual, requiredImports);
}
use of org.dmg.pmml.DataField in project jpmml-r by jpmml.
the class FormulaUtil method createFormula.
/**
 * Builds a {@link Formula} from the attributes of an R {@code terms} object.
 *
 * <p>For every row name of the {@code factors} attribute, the variable is registered with the
 * formula either as a derived field (when its name translates to a function expression matched
 * by one of the {@code hasId} checks below: {@code cut}, {@code I}, {@code ifelse},
 * {@code mapvalues}, {@code revalue}) or as a plain data field. Afterwards, every field name
 * collected in the local {@code VariableMap} that the encoder does not yet have a data field
 * for is created on the encoder.
 *
 * @param terms object supplying the {@code factors} and {@code dataClasses} attributes
 * @param context source of the per-variable categories and, optionally, of the data columns
 * @param encoder encoder on which fields are created, renamed and looked up
 * @return the formula populated with one field per variable
 */
public static Formula createFormula(RExp terms, FormulaContext context, RExpEncoder encoder) {
Formula formula = new Formula(encoder);
RIntegerVector factors = (RIntegerVector) terms.getAttributeValue("factors");
RStringVector dataClasses = (RStringVector) terms.getAttributeValue("dataClasses");
// dimnames(0) supplies the variable names iterated below.
RStringVector variableRows = factors.dimnames(0);
// NOTE(review): termColumns is read but not used anywhere else in this method.
RStringVector termColumns = factors.dimnames(1);
// Passed to the encode* helpers and iterated after the main loop.
// NOTE(review): presumably the helpers record the fields referenced by each expression here - confirm.
VariableMap expressionFields = new VariableMap();
for (int i = 0; i < variableRows.size(); i++) {
String variable = variableRows.getDequotedValue(i);
FieldName name = FieldName.create(variable);
// Continuous by default; switched to categorical when the context reports categories.
OpType opType = OpType.CONTINUOUS;
DataType dataType = RExpUtil.getDataType(dataClasses.getValue(variable));
List<String> categories = context.getCategories(variable);
if (categories != null && categories.size() > 0) {
opType = OpType.CATEGORICAL;
}
Expression expression = null;
// Stays equal to name unless a recognized function expression is handled below.
FieldName shortName = name;
// Labeled block: every "break expression" leaves with expression still null, so the
// variable is then handled by the plain data field branch further down.
expression: if (variable.indexOf('(') > -1 && variable.indexOf(')') > -1) {
FunctionExpression functionExpression;
try {
functionExpression = (FunctionExpression) ExpressionTranslator.translateExpression(variable);
} catch (Exception e) {
// Any failure here, including a failing cast, means the variable is not treated as a function call.
break expression;
}
if (functionExpression.hasId("base", "cut")) {
expression = encodeCutExpression(functionExpression, categories, expressionFields, encoder);
} else if (functionExpression.hasId("base", "I")) {
expression = encodeIdentityExpression(functionExpression, expressionFields, encoder);
} else if (functionExpression.hasId("base", "ifelse")) {
expression = encodeIfElseExpression(functionExpression, expressionFields, encoder);
} else if (functionExpression.hasId("plyr", "mapvalues")) {
expression = encodeMapValuesExpression(functionExpression, categories, expressionFields, encoder);
} else if (functionExpression.hasId("plyr", "revalue")) {
expression = encodeReValueExpression(functionExpression, categories, expressionFields, encoder);
} else {
// Unrecognized function: fall back to the plain data field branch.
break expression;
}
// The short name is built from the trimmed text of the argument fetched with ("x", 0):
// the bare argument for I(), otherwise "function(argument)".
FunctionExpression.Argument xArgument = functionExpression.getArgument("x", 0);
String value = (xArgument.formatExpression()).trim();
shortName = FieldName.create(functionExpression.hasId("base", "I") ? value : (functionExpression.getFunction() + "(" + value + ")"));
}
if (expression != null) {
// Recognized function call: register a derived field, tagged with an extension built from the variable text.
DerivedField derivedField = encoder.createDerivedField(name, opType, dataType, expression).addExtensions(createExtension(variable));
if (categories != null && categories.size() > 0) {
formula.addField(derivedField, categories);
} else {
formula.addField(derivedField);
}
// Rename only when a short name different from the full variable name was derived.
if (!(name).equals(shortName)) {
encoder.renameField(name, shortName);
}
} else {
// Boolean variables get the fixed categories "false" and "true", which overwrite any
// categories from the context and route them to the categorical branch below.
if ((DataType.BOOLEAN).equals(dataType)) {
categories = Arrays.asList("false", "true");
}
if (categories != null && categories.size() > 0) {
DataField dataField = encoder.createDataField(name, OpType.CATEGORICAL, dataType, categories);
List<String> categoryNames;
List<String> categoryValues;
// Booleans use upper-case names with lower-case values; every other data type uses
// the categories unchanged for both names and values.
switch(dataType) {
case BOOLEAN:
categoryNames = Arrays.asList("FALSE", "TRUE");
categoryValues = Arrays.asList("false", "true");
break;
default:
categoryNames = categories;
categoryValues = categories;
break;
}
formula.addField(dataField, categoryNames, categoryValues);
} else {
DataField dataField = encoder.createDataField(name, OpType.CONTINUOUS, dataType);
formula.addField(dataField);
}
}
}
// Second pass: create a data field on the encoder for every collected field name it lacks.
Collection<Map.Entry<FieldName, List<String>>> entries = expressionFields.entrySet();
for (Map.Entry<FieldName, List<String>> entry : entries) {
FieldName name = entry.getKey();
List<String> categories = entry.getValue();
DataField dataField = encoder.getDataField(name);
if (dataField == null) {
// Defaults: continuous double, refined by the categories and by the data column when available.
OpType opType = OpType.CONTINUOUS;
DataType dataType = DataType.DOUBLE;
if (categories != null && categories.size() > 0) {
opType = OpType.CATEGORICAL;
}
RGenericVector data = context.getData();
if (data != null && data.hasValue(name.getValue())) {
RVector<?> column = (RVector<?>) data.getValue(name.getValue());
dataType = column.getDataType();
}
// The assigned value is not read afterwards and these fields are not added to the formula.
// NOTE(review): presumably the call matters only for registering the field on the encoder - confirm.
dataField = encoder.createDataField(name, opType, dataType, categories);
}
}
return formula;
}
use of org.dmg.pmml.DataField in project jpmml-r by jpmml.
the class GLMConverter method encodeSchema.
@Override
public void encodeSchema(RExpEncoder encoder) {
RGenericVector glm = getObject();
RGenericVector family = (RGenericVector) glm.getValue("family");
RGenericVector model = (RGenericVector) glm.getValue("model");
RStringVector familyFamily = (RStringVector) family.getValue("family");
super.encodeSchema(encoder);
MiningFunction miningFunction = getMiningFunction(familyFamily.asScalar());
switch(miningFunction) {
case CLASSIFICATION:
Label label = encoder.getLabel();
RIntegerVector variable = (RIntegerVector) model.getValue((label.getName()).getValue());
DataField dataField = (DataField) encoder.toCategorical(label.getName(), RExpUtil.getFactorLevels(variable));
encoder.setLabel(dataField);
break;
default:
break;
}
}
Aggregations