Usage of org.kie.pmml.api.models.TargetField in the drools project (kiegroup).
Class KiePMMLTargetInstanceFactory, method getKiePMMLTarget:
/**
 * Converts a PMML {@code Target} into a {@code KiePMMLTarget}.
 * <p>
 * Each PMML target value is converted through {@code getKieTargetValue}; when the target has no values an
 * empty list is used. A missing opType, field or castInteger is mapped to {@code null}, while min, max,
 * rescaleConstant and rescaleFactor are copied as-is into the intermediate {@code TargetField}.
 * The resulting {@code KiePMMLTarget} is named after that {@code TargetField} and built with an empty list
 * as the builder's second argument.
 *
 * @param target the PMML target to convert
 * @return the built {@code KiePMMLTarget}
 */
public static KiePMMLTarget getKiePMMLTarget(final Target target) {
    List<TargetValue> kieTargetValues = Collections.emptyList();
    if (target.hasTargetValues()) {
        kieTargetValues = target.getTargetValues()
                .stream()
                .map(KiePMMLTargetInstanceFactory::getKieTargetValue)
                .collect(Collectors.toList());
    }
    OP_TYPE kieOpType = null;
    if (target.getOpType() != null) {
        kieOpType = OP_TYPE.byName(target.getOpType().value());
    }
    String fieldName = null;
    if (target.getField() != null) {
        fieldName = target.getField().getValue();
    }
    CAST_INTEGER kieCastInteger = null;
    if (target.getCastInteger() != null) {
        kieCastInteger = CAST_INTEGER.byName(target.getCastInteger().value());
    }
    final TargetField targetField = new TargetField(kieTargetValues,
                                                    kieOpType,
                                                    fieldName,
                                                    kieCastInteger,
                                                    target.getMin(),
                                                    target.getMax(),
                                                    target.getRescaleConstant(),
                                                    target.getRescaleFactor());
    return KiePMMLTarget.builder(targetField.getName(), Collections.emptyList(), targetField).build();
}
Usage of org.kie.pmml.api.models.TargetField in the drools project (kiegroup).
Class KiePMMLModelFactoryUtils, method init:
/**
 * Initialize the given <code>ClassOrInterfaceDeclaration</code> with all the <b>common</b> code needed to
 * generate a <code>KiePMMLModel</code>.
 * <p>
 * Concretely, this method:
 * <ul>
 *     <li>populates the template's default constructor with the model name, the generated class name and the
 *     Kie mining/output/target fields;</li>
 *     <li>adds the transformation dictionary and the local transformations to the template;</li>
 *     <li>sets the {@code pmmlMODEL}, {@code miningFunction} and {@code targetField} assignments in the
 *     constructor body ({@code miningFunction} and {@code targetField} are assigned {@code null} when the
 *     model does not define them);</li>
 *     <li>adds the "created mining fields" method and assigns its result to {@code kiePMMLMiningFields};</li>
 *     <li>only when the model declares an {@code Output}, adds the "created output fields" method and assigns
 *     its result to {@code kiePMMLOutputFields}.</li>
 * </ul>
 *
 * @param compilationDTO the compilation data of the model being generated
 * @param modelTemplate the class template to populate; it must declare a default constructor
 * @throws KiePMMLInternalException if {@code modelTemplate} does not declare a default constructor
 */
public static void init(final CompilationDTO<? extends Model> compilationDTO, final ClassOrInterfaceDeclaration modelTemplate) {
    final ConstructorDeclaration constructorDeclaration = modelTemplate.getDefaultConstructor()
            .orElseThrow(() -> new KiePMMLInternalException(String.format(MISSING_DEFAULT_CONSTRUCTOR, modelTemplate.getName())));
    final String name = compilationDTO.getModelName();
    final String generatedClassName = compilationDTO.getSimpleClassName();
    final List<MiningField> miningFields = compilationDTO.getKieMiningFields();
    final List<OutputField> outputFields = compilationDTO.getKieOutputFields();
    final List<TargetField> targetFields = compilationDTO.getKieTargetFields();
    final MINING_FUNCTION miningFunction = compilationDTO.getMINING_FUNCTION();
    final Expression miningFunctionExpression = miningFunction != null
            ? getEnumConstantExpression(miningFunction)
            : new NullLiteralExpr();
    final NameExpr pmmlMODELExpression = getEnumConstantExpression(compilationDTO.getPMML_MODEL());
    final Expression targetFieldExpression = getNullableStringLiteralExpression(compilationDTO.getTargetFieldName());
    setKiePMMLModelConstructor(generatedClassName, constructorDeclaration, name, miningFields, outputFields, targetFields);
    addTransformationsInClassOrInterfaceDeclaration(modelTemplate, compilationDTO.getTransformationDictionary(), compilationDTO.getLocalTransformations());
    final BlockStmt body = constructorDeclaration.getBody();
    CommonCodegenUtils.setAssignExpressionValue(body, "pmmlMODEL", pmmlMODELExpression);
    CommonCodegenUtils.setAssignExpressionValue(body, "miningFunction", miningFunctionExpression);
    CommonCodegenUtils.setAssignExpressionValue(body, "targetField", targetFieldExpression);
    addGetCreatedKiePMMLMiningFieldsMethod(modelTemplate, compilationDTO.getMiningSchema().getMiningFields(), compilationDTO.getFields());
    CommonCodegenUtils.setAssignExpressionValue(body, "kiePMMLMiningFields", getThisMethodCallExpression(GET_CREATED_KIEPMMLMININGFIELDS));
    if (compilationDTO.getOutput() != null) {
        addGetCreatedKiePMMLOutputFieldsMethod(modelTemplate, compilationDTO.getOutput().getOutputFields());
        CommonCodegenUtils.setAssignExpressionValue(body, "kiePMMLOutputFields", getThisMethodCallExpression(GET_CREATED_KIEPMMLOUTPUTFIELDS));
    }
}

/**
 * Builds a {@code NameExpr} that references the given enum constant in the form
 * {@code <canonical enum type name>.<constant name>}.
 * <p>
 * {@link Enum#getDeclaringClass()} is used rather than {@code getClass()} because, for a constant declared with
 * its own body, {@code getClass()} returns an anonymous subclass whose name is not a valid source reference.
 * The canonical name (rather than the binary name) keeps nested enum types referenced with {@code .}
 * instead of {@code $}.
 *
 * @param enumConstant the constant to reference; must not be {@code null}
 * @return a name expression usable in generated source
 */
private static NameExpr getEnumConstantExpression(final Enum<?> enumConstant) {
    return new NameExpr(enumConstant.getDeclaringClass().getCanonicalName() + "." + enumConstant.name());
}

/**
 * Returns a {@code StringLiteralExpr} for {@code value}, or a {@code NullLiteralExpr} when {@code value} is
 * {@code null}.
 *
 * @param value the string to turn into a literal, may be {@code null}
 * @return the corresponding literal expression
 */
private static Expression getNullableStringLiteralExpression(final String value) {
    return value != null ? new StringLiteralExpr(value) : new NullLiteralExpr();
}

/**
 * Returns a {@code this.<methodName>()} call expression without arguments.
 *
 * @param methodName the name of the method to invoke on {@code this}
 * @return the method call expression
 */
private static MethodCallExpr getThisMethodCallExpression(final String methodName) {
    final MethodCallExpr toReturn = new MethodCallExpr();
    toReturn.setScope(new ThisExpr());
    toReturn.setName(methodName);
    return toReturn;
}
Usage of org.kie.pmml.api.models.TargetField in the drools project (kiegroup).
Class KiePMMLTargetTest, method applyCastInteger:
/**
 * Verifies {@code applyCastInteger}: without a castInteger the value is returned unchanged, while
 * {@code CAST_INTEGER.ROUND} rounds it.
 */
@Test
public void applyCastInteger() {
    // No castInteger configured: the input passes through untouched
    final TargetField withoutCast = new TargetField(Collections.emptyList(), null, "string", null, null, null, null, null);
    final KiePMMLTarget uncastTarget = getBuilder(withoutCast).build();
    assertEquals(2.718, (double) uncastTarget.applyCastInteger(2.718), 0.0);

    // ROUND configured: 2.718 is rounded to 3
    final TargetField withRound = new TargetField(Collections.emptyList(), null, "string", CAST_INTEGER.ROUND, null, null, null, null);
    final KiePMMLTarget roundingTarget = getBuilder(withRound).build();
    assertEquals(3.0, (double) roundingTarget.applyCastInteger(2.718), 0.0);
}
Usage of org.kie.pmml.api.models.TargetField in the drools project (kiegroup).
Class KiePMMLTargetTest, method applyMin:
/**
 * Verifies {@code applyMin}: without a min the value is returned unchanged; with a min of 4.34, smaller
 * values are raised to it and larger values are kept.
 */
@Test
public void applyMin() {
    // No min configured: the input passes through untouched
    final TargetField unbounded = new TargetField(Collections.emptyList(), null, "string", null, null, null, null, null);
    final KiePMMLTarget unboundedTarget = getBuilder(unbounded).build();
    assertEquals(4.33, unboundedTarget.applyMin(4.33), 0.0);

    // Min of 4.34: values below are clamped up, values above are kept
    final TargetField lowerBounded = new TargetField(Collections.emptyList(), null, "string", null, 4.34, null, null, null);
    final KiePMMLTarget lowerBoundedTarget = getBuilder(lowerBounded).build();
    assertEquals(4.34, lowerBoundedTarget.applyMin(4.33), 0.0);
    assertEquals(4.35, lowerBoundedTarget.applyMin(4.35), 0.0);
}
Usage of org.kie.pmml.api.models.TargetField in the drools project (kiegroup).
Class KiePMMLTargetTest, method applyRescaleFactor:
/**
 * Verifies {@code applyRescaleFactor}: without a rescale factor the value is returned unchanged; with a
 * factor of 2.0 the value is doubled.
 */
@Test
public void applyRescaleFactor() {
    // No rescaleFactor configured: the input passes through untouched
    final TargetField unscaled = new TargetField(Collections.emptyList(), null, "string", null, null, null, null, null);
    final KiePMMLTarget unscaledTarget = getBuilder(unscaled).build();
    assertEquals(4.0, unscaledTarget.applyRescaleFactor(4.0), 0.0);

    // rescaleFactor of 2.0: 4.0 becomes 8.0
    final TargetField doubled = new TargetField(Collections.emptyList(), null, "string", null, null, null, null, 2.0);
    final KiePMMLTarget doublingTarget = getBuilder(doubled).build();
    assertEquals(8.0, doublingTarget.applyRescaleFactor(4.0), 0.0);
}
Aggregations