
Example 6 with LOCAL_TERM_WEIGHTS

Use of org.kie.pmml.api.enums.LOCAL_TERM_WEIGHTS in project drools by kiegroup.

From class KiePMMLTextIndexTest, method evaluateNoTextIndex0Normalizations.

@Test
public void evaluateNoTextIndex0Normalizations() {
    // <Constant>brown fox</Constant>
    final KiePMMLConstant kiePMMLConstant = new KiePMMLConstant("NAME-1", Collections.emptyList(), TERM_0, null);
    List<KiePMMLNameValue> kiePMMLNameValues = Collections.singletonList(new KiePMMLNameValue(FIELD_NAME, TEXT_0));
    ProcessingDTO processingDTO = getProcessingDTO(kiePMMLNameValues);
    double frequency = 3.0;
    double logarithmic = Math.log10(1 + frequency);
    int maxFrequency = 2;
    // explicit (double) cast, per Sonar rule java:S2184 ("Math operands should be cast before assignment")
    double augmentedNormalizedTermFrequency = 0.5 * (1 + (frequency / (double) maxFrequency));
    Map<LOCAL_TERM_WEIGHTS, Double> expectedResults = new HashMap<>();
    expectedResults.put(TERM_FREQUENCY, frequency);
    expectedResults.put(BINARY, 1.0);
    expectedResults.put(LOGARITHMIC, logarithmic);
    expectedResults.put(AUGMENTED_NORMALIZED_TERM_FREQUENCY, augmentedNormalizedTermFrequency);
    expectedResults.forEach((localTermWeights, expected) -> {
        KiePMMLTextIndex kiePMMLTextIndex = KiePMMLTextIndex.builder(FIELD_NAME, Collections.emptyList(), kiePMMLConstant)
                .withMaxLevenshteinDistance(2)
                .withLocalTermWeights(localTermWeights)
                .withIsCaseSensitive(true)
                .withWordSeparatorCharacterRE("\\s+")
                .build();
        assertEquals(expected, kiePMMLTextIndex.evaluate(processingDTO));
    });
}
Also used: CommonTestingUtility.getProcessingDTO (org.kie.pmml.commons.CommonTestingUtility.getProcessingDTO), ProcessingDTO (org.kie.pmml.commons.model.ProcessingDTO), HashMap (java.util.HashMap), KiePMMLNameValue (org.kie.pmml.commons.model.tuples.KiePMMLNameValue), LOCAL_TERM_WEIGHTS (org.kie.pmml.api.enums.LOCAL_TERM_WEIGHTS), Test (org.junit.Test)
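The expected values in the test encode four local term weightings applied to a raw term frequency. The sketch below reproduces that arithmetic in isolation. The enum LocalTermWeight and the method weight are hypothetical stand-ins mirroring LOCAL_TERM_WEIGHTS, not KIE API. The formulas are copied from the test's expected results rather than from the KIE implementation. Returning 0.0 for BINARY when the frequency is zero is an assumption, since the test only covers a positive frequency.

```java
import java.util.EnumMap;
import java.util.Map;

public class LocalTermWeightSketch {

    // Hypothetical stand-in for org.kie.pmml.api.enums.LOCAL_TERM_WEIGHTS
    enum LocalTermWeight { TERM_FREQUENCY, BINARY, LOGARITHMIC, AUGMENTED_NORMALIZED_TERM_FREQUENCY }

    // Applies one local term weighting to a raw term frequency, using the
    // same formulas as the expected values in the test above.
    static double weight(LocalTermWeight type, double frequency, int maxFrequency) {
        switch (type) {
            case TERM_FREQUENCY:
                return frequency;
            case BINARY:
                // assumption: 1.0 if the term occurs at all, 0.0 otherwise
                return frequency > 0 ? 1.0 : 0.0;
            case LOGARITHMIC:
                return Math.log10(1 + frequency);
            case AUGMENTED_NORMALIZED_TERM_FREQUENCY:
                // explicit cast, as in the test (Sonar java:S2184)
                return 0.5 * (1 + (frequency / (double) maxFrequency));
            default:
                throw new IllegalArgumentException("Unsupported weighting: " + type);
        }
    }

    public static void main(String[] args) {
        double frequency = 3.0;
        int maxFrequency = 2;
        // EnumMap iterates in declaration order of the enum constants
        Map<LocalTermWeight, Double> results = new EnumMap<>(LocalTermWeight.class);
        for (LocalTermWeight type : LocalTermWeight.values()) {
            results.put(type, weight(type, frequency, maxFrequency));
        }
        results.forEach((type, value) -> System.out.println(type + " = " + value));
    }
}
```

With the test's inputs (frequency = 3.0, maxFrequency = 2) this yields 3.0 for TERM_FREQUENCY, 1.0 for BINARY, log10(4) for LOGARITHMIC and 1.25 for AUGMENTED_NORMALIZED_TERM_FREQUENCY. Those are the same values the test puts into expectedResults.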

Aggregations

LOCAL_TERM_WEIGHTS (org.kie.pmml.api.enums.LOCAL_TERM_WEIGHTS): 6
HashMap (java.util.HashMap): 4
Test (org.junit.Test): 4
LevenshteinDistance (org.apache.commons.text.similarity.LevenshteinDistance): 2
COUNT_HITS (org.kie.pmml.api.enums.COUNT_HITS): 2
CommonTestingUtility.getProcessingDTO (org.kie.pmml.commons.CommonTestingUtility.getProcessingDTO): 2
ProcessingDTO (org.kie.pmml.commons.model.ProcessingDTO): 2
KiePMMLNameValue (org.kie.pmml.commons.model.tuples.KiePMMLNameValue): 2
NodeList (com.github.javaparser.ast.NodeList): 1
MethodDeclaration (com.github.javaparser.ast.body.MethodDeclaration): 1
VariableDeclarator (com.github.javaparser.ast.body.VariableDeclarator): 1
Expression (com.github.javaparser.ast.expr.Expression): 1
MethodCallExpr (com.github.javaparser.ast.expr.MethodCallExpr): 1
NameExpr (com.github.javaparser.ast.expr.NameExpr): 1
NullLiteralExpr (com.github.javaparser.ast.expr.NullLiteralExpr): 1
StringLiteralExpr (com.github.javaparser.ast.expr.StringLiteralExpr): 1
BlockStmt (com.github.javaparser.ast.stmt.BlockStmt): 1
TextIndexNormalization (org.dmg.pmml.TextIndexNormalization): 1
KiePMMLException (org.kie.pmml.api.exceptions.KiePMMLException): 1
KiePMMLExpressionFactory.getKiePMMLExpressionBlockStmt (org.kie.pmml.compiler.commons.codegenfactories.KiePMMLExpressionFactory.getKiePMMLExpressionBlockStmt): 1
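The test configures the text index with withMaxLevenshteinDistance(2), withIsCaseSensitive(true) and withWordSeparatorCharacterRE("\\s+"), and LevenshteinDistance from Apache Commons Text appears among the types used alongside LOCAL_TERM_WEIGHTS. The sketch below shows one plausible way a fuzzy term frequency could be counted under those settings. It is a hypothetical illustration, not the KIE implementation. To stay dependency-free, it uses its own dynamic-programming Levenshtein distance instead of Commons Text. The windowing scheme slides a window as wide as the term over the text words, which is an assumption. The sample text is invented, because TEXT_0 is not shown in the source.

```java
import java.util.Arrays;
import java.util.Locale;

public class FuzzyTermCountSketch {

    // Classic two-row dynamic-programming Levenshtein distance
    // (insertion, deletion and substitution each cost 1).
    static int levenshtein(String a, String b) {
        int[] prev = new int[b.length() + 1];
        int[] curr = new int[b.length() + 1];
        for (int j = 0; j <= b.length(); j++) {
            prev[j] = j;
        }
        for (int i = 1; i <= a.length(); i++) {
            curr[0] = i;
            for (int j = 1; j <= b.length(); j++) {
                int cost = a.charAt(i - 1) == b.charAt(j - 1) ? 0 : 1;
                curr[j] = Math.min(Math.min(curr[j - 1] + 1, prev[j] + 1), prev[j - 1] + cost);
            }
            int[] tmp = prev;
            prev = curr;
            curr = tmp;
        }
        return prev[b.length()];
    }

    // Hypothetical counting scheme: split text and term with the same separator
    // regex, slide a window as wide as the term over the text words, and count a
    // hit whenever the joined window is within maxDistance edits of the term.
    static int countHits(String text, String term, String separatorRegex,
                         int maxDistance, boolean caseSensitive) {
        String source = caseSensitive ? text : text.toLowerCase(Locale.ROOT);
        String target = caseSensitive ? term : term.toLowerCase(Locale.ROOT);
        String[] words = source.trim().split(separatorRegex);
        String[] termWords = target.trim().split(separatorRegex);
        String joinedTerm = String.join(" ", termWords);
        int hits = 0;
        for (int start = 0; start + termWords.length <= words.length; start++) {
            String window = String.join(" ", Arrays.copyOfRange(words, start, start + termWords.length));
            if (levenshtein(window, joinedTerm) <= maxDistance) {
                hits++;
            }
        }
        return hits;
    }

    public static void main(String[] args) {
        String text = "The brown fox jumps over the browny fox and a brown box";
        // "brown fox" (0 edits), "browny fox" (1) and "brown box" (1) all count
        System.out.println(countHits(text, "brown fox", "\\s+", 2, true)); // prints 3
    }
}
```

Comparing whole windows rather than individual words keeps a multi-word term such as "brown fox" intact. As a result, a single misspelled word inside the term still falls within the distance budget.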