Usage of org.dmg.pmml.DefineFunction in the jpmml-sparkml project (by jpmml).
Class TermFeature, method toWeightedTermFeature:
/**
 * Returns a {@link WeightedTermFeature} for this term with the given weight.
 *
 * <p>The weighted feature is backed by a second define function. Its name is this feature's
 * function name with {@code "tf@"} replaced by {@code "tf-idf@"}. It takes the original
 * parameters plus an extra {@code weight} parameter, and its body multiplies the original
 * expression by that parameter. The weighted function is created and registered on the
 * encoder only if the encoder does not already hold a function under that name. Otherwise
 * the registered one is reused.
 *
 * @param weight the weight value carried by the returned feature
 * @return a weighted view of this term feature
 */
public WeightedTermFeature toWeightedTermFeature(double weight) {
    PMMLEncoder encoder = ensureEncoder();
    DefineFunction termFunction = getDefineFunction();

    String weightedName = termFunction.getName().replace("tf@", "tf-idf@");

    // Reuse the weighted function if an earlier call already registered it
    DefineFunction weightedFunction = encoder.getDefineFunction(weightedName);
    if (weightedFunction == null) {
        ParameterField weightParameter = new ParameterField(FieldName.create("weight"));

        // Copy, so the original function's parameter list is left untouched
        List<ParameterField> weightedParameters = new ArrayList<>(termFunction.getParameterFields());
        weightedParameters.add(weightParameter);

        FieldRef weightReference = new FieldRef(weightParameter.getName());
        Apply weightedExpression = PMMLUtil.createApply("*", termFunction.getExpression(), weightReference);

        weightedFunction = new DefineFunction(weightedName, OpType.CONTINUOUS, weightedParameters);
        weightedFunction.setDataType(DataType.DOUBLE);
        weightedFunction.setExpression(weightedExpression);

        encoder.addDefineFunction(weightedFunction);
    }

    return new WeightedTermFeature(encoder, weightedFunction, getFeature(), getValue(), weight);
}
Usage of org.dmg.pmml.DefineFunction in the jpmml-sparkml project (by jpmml).
Class TermFeature, method createApply:
/**
 * Builds an {@code Apply} that invokes this feature's define function.
 *
 * <p>The call takes two arguments: the feature's field reference, and the term value as a
 * {@code STRING} constant.
 *
 * @return the function-call expression for this term
 */
public Apply createApply() {
    DefineFunction termFunction = getDefineFunction();
    Feature termFeature = getFeature();
    Constant termConstant = PMMLUtil.createConstant(getValue(), DataType.STRING);

    String functionName = termFunction.getName();
    return PMMLUtil.createApply(functionName, termFeature.ref(), termConstant);
}
Usage of org.dmg.pmml.DefineFunction in the drools project (by kiegroup).
Class PMMLModelTestUtils, method getDefineFunction:
/**
 * Builds a {@link DefineFunction} test fixture.
 *
 * <p>The fixture is named {@code functionName} and gets a random data type and op type. Its
 * expression is an {@code INTEGER} constant {@code 5}, and it has three parameter fields named
 * {@code ParameterField-0} to {@code ParameterField-2}.
 *
 * @param functionName name assigned to the define function
 * @return the populated define function
 */
public static DefineFunction getDefineFunction(String functionName) {
    DefineFunction defineFunction = new DefineFunction();
    defineFunction.setName(functionName);
    // Data type is drawn before op type, same as before, so the random sequence is unchanged
    defineFunction.setDataType(getRandomDataType());
    defineFunction.setOpType(getRandomOpType());

    Constant five = new Constant(5);
    five.setDataType(DataType.INTEGER);
    defineFunction.setExpression(five);

    for (int index = 0; index < 3; index++) {
        defineFunction.addParameterFields(getParameterField("ParameterField-" + index));
    }
    return defineFunction;
}
Usage of org.dmg.pmml.DefineFunction in the drools project (by kiegroup).
Class KiePMMLTextIndexNormalizationFactoryTest, method setup:
/**
 * Loads {@code TRANSFORMATIONS_SAMPLE} and stores the first {@code TextIndexNormalization}
 * in {@code TEXTINDEXNORMALIZATION}.
 *
 * <p>The normalization is taken from the {@code TextIndex} expression of the define function
 * named {@code TEXT_INDEX_NORMALIZATION_FUNCTION}. Each missing piece of the sample (the
 * transformation dictionary, the define function, the TextIndex expression, or the
 * normalization) produces a descriptive {@code RuntimeException} rather than an NPE,
 * ClassCastException or IndexOutOfBoundsException.
 *
 * @throws Exception if loading the sample fails
 */
@BeforeClass
public static void setup() throws Exception {
    PMML pmmlModel = KiePMMLUtil.load(getFileInputStream(TRANSFORMATIONS_SAMPLE), TRANSFORMATIONS_SAMPLE);
    if (pmmlModel.getTransformationDictionary() == null) {
        throw new RuntimeException("Missing transformation dictionary in " + TRANSFORMATIONS_SAMPLE);
    }
    // The lookup is over define functions (the old message wrongly said "derived field")
    DefineFunction definedFunction = pmmlModel.getTransformationDictionary()
            .getDefineFunctions()
            .stream()
            .filter(defineFunction -> TEXT_INDEX_NORMALIZATION_FUNCTION.equals(defineFunction.getName()))
            .findFirst()
            .orElseThrow(() -> new RuntimeException("Missing define function " + TEXT_INDEX_NORMALIZATION_FUNCTION));
    if (!(definedFunction.getExpression() instanceof TextIndex)) {
        throw new RuntimeException("Define function " + TEXT_INDEX_NORMALIZATION_FUNCTION
                + " does not have a TextIndex expression");
    }
    TextIndex textIndex = (TextIndex) definedFunction.getExpression();
    if (textIndex.getTextIndexNormalizations().isEmpty()) {
        throw new RuntimeException("No TextIndexNormalization in define function "
                + TEXT_INDEX_NORMALIZATION_FUNCTION);
    }
    TEXTINDEXNORMALIZATION = textIndex.getTextIndexNormalizations().get(0);
}
Usage of org.dmg.pmml.DefineFunction in the drools project (by kiegroup).
Class KiePMMLDefineFunctionFactoryTest, method getDefineFunctionVariableDeclaration:
/**
 * Verifies the code generated by {@code KiePMMLDefineFunctionFactory.getDefineFunctionVariableDeclaration}.
 *
 * <p>The input is a two-parameter define function whose body divides a constant by a field
 * reference. The generated block must equal the {@code TEST_01_SOURCE} template filled with
 * the fixture's values, and it must compile with the listed imports.
 */
@Test
public void getDefineFunctionVariableDeclaration() throws IOException {
    // Two DOUBLE/CONTINUOUS parameters that differ only in name and display name
    ParameterField firstParameter = new ParameterField(FieldName.create(PARAM_1));
    firstParameter.setDataType(DataType.DOUBLE);
    firstParameter.setOpType(OpType.CONTINUOUS);
    firstParameter.setDisplayName("displayName1");
    ParameterField secondParameter = new ParameterField(FieldName.create(PARAM_2));
    secondParameter.setDataType(DataType.DOUBLE);
    secondParameter.setOpType(OpType.CONTINUOUS);
    secondParameter.setDisplayName("displayName2");

    // Function body: Apply("/", Constant(value1), FieldRef("FIELD_REF"))
    Constant numerator = new Constant();
    numerator.setValue(value1);
    FieldRef denominator = new FieldRef();
    denominator.setField(FieldName.create("FIELD_REF"));
    Apply division = new Apply();
    division.setFunction("/");
    division.addExpressions(numerator, denominator);

    DefineFunction customFunction = new DefineFunction();
    customFunction.setName(CUSTOM_FUNCTION);
    customFunction.addParameterFields(firstParameter, secondParameter);
    customFunction.setDataType(DataType.DOUBLE);
    customFunction.setOpType(OpType.CONTINUOUS);
    customFunction.setExpression(division);

    String firstDataType = getDATA_TYPEString(firstParameter.getDataType());
    String secondDataType = getDATA_TYPEString(secondParameter.getDataType());
    String functionDataType = getDATA_TYPEString(customFunction.getDataType());
    String firstOpType = getOP_TYPEString(firstParameter.getOpType());
    String secondOpType = getOP_TYPEString(secondParameter.getOpType());
    String functionOpType = getOP_TYPEString(customFunction.getOpType());

    BlockStmt retrieved = KiePMMLDefineFunctionFactory.getDefineFunctionVariableDeclaration(customFunction);

    // TEST_01_SOURCE is a String.format template; argument order must match its placeholders
    String template = getFileContent(TEST_01_SOURCE);
    String expectedSource = String.format(template,
            firstParameter.getName().getValue(), firstDataType, firstOpType, firstParameter.getDisplayName(),
            secondParameter.getName().getValue(), secondDataType, secondOpType, secondParameter.getDisplayName(),
            numerator.getValue(), denominator.getField().getValue(), division.getFunction(),
            division.getInvalidValueTreatment().value(), functionDataType, functionOpType);
    Statement expected = JavaParserUtils.parseBlock(expectedSource);
    assertTrue(JavaParserUtils.equalsNode(expected, retrieved));

    List<Class<?>> expectedImports = Arrays.asList(KiePMMLParameterField.class, KiePMMLConstant.class,
            KiePMMLFieldRef.class, KiePMMLApply.class, KiePMMLDefineFunction.class, Arrays.class,
            Collections.class);
    commonValidateCompilationWithImports(retrieved, expectedImports);
}
Aggregations