use of org.dmg.pmml.DefineFunction in project drools by kiegroup.
the class KiePMMLTransformationDictionaryFactoryTest method getDefineFunction.
/**
 * Builds a {@link DefineFunction} fixture whose name, parameter names, parameter display
 * names and referenced field name all carry {@code counter} as a suffix.
 *
 * <p>The function is declared as DOUBLE / CONTINUOUS, takes two DOUBLE / CONTINUOUS
 * parameters, and its body is an {@link Apply} of the {@code "/"} function over a
 * {@link Constant} (holding {@code value1}) and a {@link FieldRef}, in that order.
 *
 * @param counter suffix appended to every generated name
 * @return the populated define function
 */
private DefineFunction getDefineFunction(int counter) {
    // Expression: "/" applied to (constant, field reference).
    final Constant dividend = new Constant();
    dividend.setValue(value1);
    final FieldRef divisor = new FieldRef();
    divisor.setField(FieldName.create("FIELD_REF" + counter));
    final Apply division = new Apply();
    division.setFunction("/");
    division.addExpressions(dividend, divisor);

    // Two formal parameters sharing the same data type and op type.
    final ParameterField firstParam = new ParameterField(FieldName.create(PARAM_1 + counter));
    firstParam.setDisplayName("displayName1" + counter);
    firstParam.setDataType(DataType.DOUBLE);
    firstParam.setOpType(OpType.CONTINUOUS);

    final ParameterField secondParam = new ParameterField(FieldName.create(PARAM_2 + counter));
    secondParam.setDisplayName("displayName2" + counter);
    secondParam.setDataType(DataType.DOUBLE);
    secondParam.setOpType(OpType.CONTINUOUS);

    // Assemble the function itself.
    final DefineFunction function = new DefineFunction();
    function.setName(CUSTOM_FUNCTION + counter);
    function.setDataType(DataType.DOUBLE);
    function.setOpType(OpType.CONTINUOUS);
    function.addParameterFields(firstParam, secondParam);
    function.setExpression(division);
    return function;
}
use of org.dmg.pmml.DefineFunction in project drools by kiegroup.
the class KiePMMLDefineFunctionInstanceFactoryTest method getKiePMMLDefineFunction.
/**
 * Converts a {@link DefineFunction} built for the name {@code "functionName"} through
 * {@code KiePMMLDefineFunctionInstanceFactory} and checks the result against its source
 * with the shared verification helper.
 */
@Test
public void getKiePMMLDefineFunction() {
    final DefineFunction source = getDefineFunction("functionName");
    final KiePMMLDefineFunction converted =
            KiePMMLDefineFunctionInstanceFactory.getKiePMMLDefineFunction(source);
    commonVerifyKiePMMLDefineFunction(converted, source);
}
use of org.dmg.pmml.DefineFunction in project drools by kiegroup.
the class KiePMMLTransformationDictionaryInstanceFactoryTest method getKiePMMLTransformationDictionary.
/**
 * Converts a random {@link TransformationDictionary} and verifies that every source
 * derived field and every source define function has a same-named counterpart in the
 * converted dictionary, delegating the per-element checks to the shared helpers.
 */
@Test
public void getKiePMMLTransformationDictionary() {
    final TransformationDictionary source = getRandomTransformationDictionary();
    final KiePMMLTransformationDictionary converted =
            KiePMMLTransformationDictionaryInstanceFactory.getKiePMMLTransformationDictionary(
                    source, Collections.emptyList());
    assertNotNull(converted);

    // Derived fields: same count, then match each source entry by name.
    final List<DerivedField> sourceFields = source.getDerivedFields();
    final List<KiePMMLDerivedField> convertedFields = converted.getDerivedFields();
    assertEquals(sourceFields.size(), convertedFields.size());
    for (DerivedField expectedField : sourceFields) {
        final String expectedName = expectedField.getName().getValue();
        final Optional<KiePMMLDerivedField> match = convertedFields.stream()
                .filter(candidate -> candidate.getName().equals(expectedName))
                .findFirst();
        assertTrue(match.isPresent());
        commonVerifyKiePMMLDerivedField(match.get(), expectedField);
    }

    // Define functions: same count, then match each source entry by name.
    final List<DefineFunction> sourceFunctions = source.getDefineFunctions();
    final List<KiePMMLDefineFunction> convertedFunctions = converted.getDefineFunctions();
    assertEquals(sourceFunctions.size(), convertedFunctions.size());
    for (DefineFunction expectedFunction : sourceFunctions) {
        final String expectedName = expectedFunction.getName();
        final Optional<KiePMMLDefineFunction> match = convertedFunctions.stream()
                .filter(candidate -> candidate.getName().equals(expectedName))
                .findFirst();
        assertTrue(match.isPresent());
        commonVerifyKiePMMLDefineFunction(match.get(), expectedFunction);
    }
}
use of org.dmg.pmml.DefineFunction in project jpmml-sparkml by jpmml.
the class CountVectorizerModelConverter method encodeFeatures.
/**
 * Encodes the transformer's vocabulary as one {@link TermFeature} per term, all backed by
 * a single term-frequency {@link DefineFunction} registered on the encoder.
 *
 * <p>The function wraps a {@link TextIndex} over a {@code document} and a {@code term}
 * parameter. For every non-empty stop word set of the input document feature, a
 * {@link TextIndexNormalization} is added whose inline table maps a regular expression
 * matching those stop words to a single space.
 *
 * @param encoder the encoder supplying the input feature and receiving the define function
 * @return one feature per vocabulary term, in vocabulary order
 * @throws IllegalArgumentException if a non-empty stop word set is encountered while the
 *         word separator pattern is neither {@code "\\s+"} nor {@code "\\W+"}, or if a
 *         vocabulary term is reported by {@code TermUtil.hasPunctuation}
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    final CountVectorizerModel model = getTransformer();
    final DocumentFeature document = (DocumentFeature) encoder.getOnlyFeature(model.getInputCol());

    final ParameterField documentField = new ParameterField(FieldName.create("document"));
    final ParameterField termField = new ParameterField(FieldName.create("term"));

    // Binary models get binary local term weights; otherwise the weights are left unset.
    final TextIndex.LocalTermWeights localTermWeights =
            model.getBinary() ? TextIndex.LocalTermWeights.BINARY : null;
    final TextIndex textIndex = new TextIndex(documentField.getName())
            .setTokenize(Boolean.TRUE)
            .setWordSeparatorCharacterRE(document.getWordSeparatorRE())
            .setLocalTermWeights(localTermWeights)
            .setExpression(new FieldRef(termField.getName()));

    for (DocumentFeature.StopWordSet stopWords : document.getStopWordSets()) {
        if (!stopWords.isEmpty()) {
            final DocumentBuilder rowBuilder = DOMUtil.createDocumentBuilder();
            final String separatorRE = document.getWordSeparatorRE();

            // The stop word pattern depends on how the document is split into words.
            final String stopWordRE;
            switch (separatorRE) {
                case "\\s+":
                    stopWordRE = "(^|\\s+)\\p{Punct}*(" + JOINER.join(stopWords) + ")\\p{Punct}*(\\s+|$)";
                    break;
                case "\\W+":
                    stopWordRE = "(\\W+)(" + JOINER.join(stopWords) + ")(\\W+)";
                    break;
                default:
                    throw new IllegalArgumentException("Expected \"\\s+\" or \"\\W+\" as splitter regex pattern, got \"" + separatorRE + "\"");
            }

            // Single-row table: replace every regex match with a blank.
            final List<String> columnNames = Arrays.asList("string", "stem", "regex");
            final List<String> cellValues = Arrays.asList(stopWordRE, " ", "true");
            final InlineTable replacements = new InlineTable()
                    .addRows(DOMUtil.createRow(rowBuilder, columnNames, cellValues));

            final TextIndexNormalization normalization = new TextIndexNormalization()
                    .setCaseSensitive(stopWords.isCaseSensitive())
                    // Handles consecutive matches. See http://stackoverflow.com/a/25085385
                    .setRecursive(Boolean.TRUE)
                    .setInlineTable(replacements);
            textIndex.addTextIndexNormalizations(normalization);
        }
    }

    // Each call draws a fresh sequence number for the function name.
    final String functionName = "tf" + "@" + String.valueOf(CountVectorizerModelConverter.SEQUENCE.getAndIncrement());
    final DefineFunction termFrequency = new DefineFunction(functionName, OpType.CONTINUOUS, null)
            .setDataType(DataType.INTEGER)
            .addParameterFields(documentField, termField)
            .setExpression(textIndex);
    encoder.addDefineFunction(termFrequency);

    final List<Feature> features = new ArrayList<>();
    for (String term : model.vocabulary()) {
        if (TermUtil.hasPunctuation(term)) {
            throw new IllegalArgumentException(term);
        }
        features.add(new TermFeature(encoder, termFrequency, document, term));
    }
    return features;
}
use of org.dmg.pmml.DefineFunction in project drools by kiegroup.
the class KiePMMLTextIndexFactoryTest method setup.
/**
 * Class-level fixture: loads the {@code TRANSFORMATIONS_SAMPLE} PMML model, looks up the
 * define function named {@code TEXT_INDEX_NORMALIZATION_FUNCTION} in its transformation
 * dictionary and stores that function's expression, cast to {@link TextIndex}, in
 * {@code TEXTINDEX}.
 *
 * @throws Exception if the sample model cannot be loaded
 * @throws RuntimeException if the model declares no define function with the expected name
 */
@BeforeClass
public static void setup() throws Exception {
    PMML pmmlModel = KiePMMLUtil.load(getFileInputStream(TRANSFORMATIONS_SAMPLE), TRANSFORMATIONS_SAMPLE);
    DefineFunction definedFunction = pmmlModel.getTransformationDictionary()
            .getDefineFunctions()
            .stream()
            .filter(defineFunction -> TEXT_INDEX_NORMALIZATION_FUNCTION.equals(defineFunction.getName()))
            .findFirst()
            // The lookup is over define functions, so the failure message must say so
            // (it previously reported a missing "derived field").
            .orElseThrow(() -> new RuntimeException("Missing define function " + TEXT_INDEX_NORMALIZATION_FUNCTION));
    TEXTINDEX = ((TextIndex) definedFunction.getExpression());
}
Aggregations