Use of org.dmg.pmml.FieldRef in project drools by kiegroup.
Class KiePMMLDerivedFieldFactoryTest, method getDerivedFieldVariableDeclarationWithApply.
/**
 * Checks the block generated by
 * {@code KiePMMLDerivedFieldFactory.getDerivedFieldVariableDeclaration} for a {@link DerivedField}
 * whose expression is an {@link Apply} with function {@code "/"} over a {@link Constant} and a
 * {@link FieldRef}.
 * <p>
 * The expected block is obtained by formatting the {@code TEST_03_SOURCE} template with values read
 * back from the model objects; the generated block must be node-equal to it and is afterwards passed
 * to {@code commonValidateCompilationWithImports} together with the classes it references.
 *
 * @throws IOException declared by the test signature (presumably raised while reading the template
 *                     file; confirm against {@code getFileContent})
 */
@Test
public void getDerivedFieldVariableDeclarationWithApply() throws IOException {
    final String variableName = "variableName";

    // Operands of the division: a constant and a reference to another field.
    final Constant constantOperand = new Constant();
    constantOperand.setValue(value1);
    final FieldRef fieldRefOperand = new FieldRef();
    fieldRefOperand.setField(FieldName.create("FIELD_REF"));

    final Apply division = new Apply();
    division.setFunction("/");
    division.addExpressions(constantOperand, fieldRefOperand);

    // Derived field under test, wrapping the Apply expression.
    final DerivedField source = new DerivedField();
    source.setName(FieldName.create(PARAM_1));
    source.setDataType(DataType.DOUBLE);
    source.setOpType(OpType.CONTINUOUS);
    source.setExpression(division);

    final String dataType = getDATA_TYPEString(source.getDataType());
    final String opType = getOP_TYPEString(source.getOpType());

    final BlockStmt retrieved =
            KiePMMLDerivedFieldFactory.getDerivedFieldVariableDeclaration(variableName, source);

    // Fill the template placeholders in the order the template expects them.
    final String template = getFileContent(TEST_03_SOURCE);
    final Object[] templateArguments = {
            constantOperand.getValue(),
            fieldRefOperand.getField().getValue(),
            division.getFunction(),
            division.getInvalidValueTreatment().value(),
            variableName,
            source.getName().getValue(),
            dataType,
            opType
    };
    final Statement expected = JavaParserUtils.parseBlock(String.format(template, templateArguments));

    final boolean sameNode = JavaParserUtils.equalsNode(expected, retrieved);
    assertTrue(sameNode);

    final List<Class<?>> imports = Arrays.asList(KiePMMLConstant.class, KiePMMLFieldRef.class, KiePMMLApply.class, KiePMMLDerivedField.class, Arrays.class, Collections.class);
    commonValidateCompilationWithImports(retrieved, imports);
}
Use of org.dmg.pmml.FieldRef in project drools by kiegroup.
Class PMMLModelTestUtils, method getPredictorTerm.
/**
 * Creates a {@link PredictorTerm} with the given name and coefficient and one {@link FieldRef}
 * per entry of {@code fieldRefNames}, in the same order.
 *
 * @param name          value used to create the term's {@link FieldName}
 * @param coefficient   coefficient assigned to the term
 * @param fieldRefNames names handed one by one to {@code PMMLModelTestUtils.getFieldRef}
 * @return the populated predictor term
 */
public static PredictorTerm getPredictorTerm(String name, double coefficient, List<String> fieldRefNames) {
    final PredictorTerm predictorTerm = new PredictorTerm();
    predictorTerm.setName(FieldName.create(name));
    predictorTerm.setCoefficient(coefficient);

    // Build the field references explicitly, one slot per requested name.
    final FieldRef[] fieldRefs = new FieldRef[fieldRefNames.size()];
    int slot = 0;
    for (String fieldRefName : fieldRefNames) {
        fieldRefs[slot] = PMMLModelTestUtils.getFieldRef(fieldRefName);
        slot++;
    }
    predictorTerm.addFieldRefs(fieldRefs);
    return predictorTerm;
}
Use of org.dmg.pmml.FieldRef in project drools by kiegroup.
Class KiePMMLComplexPartialScoreFactoryTest, method getComplexPartialScoreVariableDeclarationWithApply.
/**
 * Checks the block generated by
 * {@code KiePMMLComplexPartialScoreFactory.getComplexPartialScoreVariableDeclaration} for a
 * {@link ComplexPartialScore} whose expression is an {@link Apply} with function {@code "/"} over a
 * {@link Constant} and a {@link FieldRef}.
 * <p>
 * The expected block comes from formatting the {@code TEST_03_SOURCE} template with values read back
 * from the model objects; the generated block must be node-equal to it and is afterwards passed to
 * {@code commonValidateCompilationWithImports} together with the classes it references.
 *
 * @throws IOException declared by the test signature (presumably raised while reading the template
 *                     file; confirm against {@code getFileContent})
 */
@Test
public void getComplexPartialScoreVariableDeclarationWithApply() throws IOException {
    final String variableName = "variableName";

    // Operands of the division: a constant and a reference to another field.
    final Constant constantOperand = new Constant();
    constantOperand.setValue(value1);
    final FieldRef fieldRefOperand = new FieldRef();
    fieldRefOperand.setField(FieldName.create("FIELD_REF"));

    final Apply division = new Apply();
    division.setFunction("/");
    division.addExpressions(constantOperand, fieldRefOperand);

    final ComplexPartialScore source = new ComplexPartialScore();
    source.setExpression(division);

    final BlockStmt retrieved =
            KiePMMLComplexPartialScoreFactory.getComplexPartialScoreVariableDeclaration(variableName, source);

    // Fill the template placeholders in the order the template expects them.
    final String template = getFileContent(TEST_03_SOURCE);
    final Object[] templateArguments = {
            constantOperand.getValue(),
            fieldRefOperand.getField().getValue(),
            division.getFunction(),
            division.getInvalidValueTreatment().value(),
            variableName
    };
    final Statement expected = JavaParserUtils.parseBlock(String.format(template, templateArguments));

    final boolean sameNode = JavaParserUtils.equalsNode(expected, retrieved);
    assertTrue(sameNode);

    final List<Class<?>> imports = Arrays.asList(KiePMMLConstant.class, KiePMMLFieldRef.class, KiePMMLApply.class, KiePMMLComplexPartialScore.class, Arrays.class, Collections.class);
    commonValidateCompilationWithImports(retrieved, imports);
}
Use of org.dmg.pmml.FieldRef in project drools by kiegroup.
Class KiePMMLComplexPartialScoreFactoryTest, method getComplexPartialScoreVariableDeclarationWithFieldRef.
/**
 * Checks the block generated by
 * {@code KiePMMLComplexPartialScoreFactory.getComplexPartialScoreVariableDeclaration} for a
 * {@link ComplexPartialScore} whose expression is a bare {@link FieldRef}.
 * <p>
 * The expected block comes from formatting the {@code TEST_02_SOURCE} template with the referenced
 * field name and the variable name; the generated block must be node-equal to it and is afterwards
 * passed to {@code commonValidateCompilationWithImports} together with the classes it references.
 *
 * @throws IOException declared by the test signature (presumably raised while reading the template
 *                     file; confirm against {@code getFileContent})
 */
@Test
public void getComplexPartialScoreVariableDeclarationWithFieldRef() throws IOException {
    final String variableName = "variableName";

    final FieldRef referencedField = new FieldRef();
    referencedField.setField(FieldName.create("FIELD_REF"));

    final ComplexPartialScore source = new ComplexPartialScore();
    source.setExpression(referencedField);

    final BlockStmt retrieved =
            KiePMMLComplexPartialScoreFactory.getComplexPartialScoreVariableDeclaration(variableName, source);

    // Template placeholders: referenced field name first, then the declared variable name.
    final String template = getFileContent(TEST_02_SOURCE);
    final Object[] templateArguments = {
            referencedField.getField().getValue(),
            variableName
    };
    final Statement expected = JavaParserUtils.parseBlock(String.format(template, templateArguments));

    final boolean sameNode = JavaParserUtils.equalsNode(expected, retrieved);
    assertTrue(sameNode);

    final List<Class<?>> imports = Arrays.asList(KiePMMLFieldRef.class, KiePMMLComplexPartialScore.class, Collections.class);
    commonValidateCompilationWithImports(retrieved, imports);
}
Use of org.dmg.pmml.FieldRef in project shifu by ShifuML.
Class NeuralNetworkModelIntegrator, method getNeuralInputs.
/**
 * Assembles the {@link NeuralInputs} of the given network from its local transformations.
 * <p>
 * A first pass over the derived fields records, for every {@link NormContinuous},
 * {@link MapValues} or {@link Discretize} expression, which field the derived field is computed
 * from, and groups derived fields by that source field. A second pass adds one input per derived
 * field that nothing else is derived from and whose root (as resolved by {@code getRoot}) is accepted
 * by {@code isRootInMiningList}; these inputs get the ids {@code "0,0"}, {@code "0,1"}, ... in
 * iteration order. Finally an input named after {@code PluginConstants.biasValue} is appended.
 *
 * @param model network whose local transformations and mining schema are read
 * @return the populated neural inputs
 */
private NeuralInputs getNeuralInputs(final NeuralNetwork model) {
    // derived field name -> name of the field it is computed from
    HashMap<FieldName, FieldName> sourceByDerived = new HashMap<FieldName, FieldName>();
    // source field name -> names of the derived fields computed from it
    HashMap<FieldName, List<FieldName>> derivedBySource = new HashMap<FieldName, List<FieldName>>();

    for (DerivedField derived : model.getLocalTransformations().getDerivedFields()) {
        FieldName source = null;
        boolean recognised = true;
        if (derived.getExpression() instanceof NormContinuous) {
            // z-scale normalization on numerical variables
            source = ((NormContinuous) derived.getExpression()).getField();
        } else if (derived.getExpression() instanceof MapValues) {
            // bin map on categorical variables: the first field/column pair names the source
            source = ((MapValues) derived.getExpression()).getFieldColumnPairs().get(0).getField();
        } else if (derived.getExpression() instanceof Discretize) {
            source = ((Discretize) derived.getExpression()).getField();
        } else {
            recognised = false;
        }
        if (recognised) {
            sourceByDerived.put(derived.getName(), source);
        }

        // Fields with an unrecognised expression are grouped under the null key.
        List<FieldName> siblings = derivedBySource.get(source);
        if (siblings == null) {
            siblings = new ArrayList<FieldName>();
            derivedBySource.put(source, siblings);
        }
        siblings.add(derived.getName());
    }

    // Keep only leaf derived fields whose root passes the mining-list check.
    List<MiningField> miningFields = model.getMiningSchema().getMiningFields();
    NeuralInputs inputs = new NeuralInputs();
    int position = 0;
    for (DerivedField candidate : model.getLocalTransformations().getDerivedFields()) {
        List<FieldName> dependents = derivedBySource.get(candidate.getName());
        boolean hasDependents = dependents != null && !dependents.isEmpty();
        FieldName root = getRoot(candidate.getName(), sourceByDerived);
        if (hasDependents || !isRootInMiningList(root, miningFields)) {
            continue;
        }
        DerivedField input = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE);
        input.setName(candidate.getName());
        input.setExpression(new FieldRef(candidate.getName()));
        inputs.addNeuralInputs(new NeuralInput("0," + position, input));
        position++;
    }

    // The bias input is always added last.
    DerivedField bias = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE);
    bias.setName(new FieldName(PluginConstants.biasValue));
    bias.setExpression(new FieldRef(new FieldName(PluginConstants.biasValue)));
    inputs.addNeuralInputs(new NeuralInput(PluginConstants.biasValue, bias));
    return inputs;
}
Aggregations