Search in sources :

Example 1 with LOCAL_TERM_WEIGHTS

use of org.kie.pmml.api.enums.LOCAL_TERM_WEIGHTS in project drools by kiegroup.

the class KiePMMLTextIndexTest method evaluateRawNoTokenize.

/**
 * Checks {@code KiePMMLTextIndex.evaluateRaw} for every {@code LOCAL_TERM_WEIGHTS} value,
 * always passing {@code false} as the second flag (presumably "tokenize", judging by the
 * test name - confirm against the {@code evaluateRaw} signature).
 *
 * <p>Three scenarios are exercised, all expecting a raw frequency of 3.0:
 * <ol>
 *   <li>first flag {@code true}, separator {@code \s+}, max frequency 2;</li>
 *   <li>first flag {@code false}, separator {@code \s+}, max frequency 3;</li>
 *   <li>first flag {@code false}, separator {@code [\s\-]}, max frequency 3.</li>
 * </ol>
 */
@Test
public void evaluateRawNoTokenize() {
    final LevenshteinDistance distance = new LevenshteinDistance(2);
    final double frequency = 3.0;
    final double logarithmic = Math.log10(1 + frequency);

    // --- scenario 1: first flag true, whitespace separator
    final int firstMaxFrequency = 2;
    // the explicit (double) cast is kept for Sonar rule java:S2184
    final double firstAugmented = 0.5 * (1 + (frequency / (double) firstMaxFrequency));
    final Map<LOCAL_TERM_WEIGHTS, Double> firstExpected = new HashMap<>();
    firstExpected.put(TERM_FREQUENCY, frequency);
    firstExpected.put(BINARY, 1.0);
    firstExpected.put(LOGARITHMIC, logarithmic);
    firstExpected.put(AUGMENTED_NORMALIZED_TERM_FREQUENCY, firstAugmented);
    for (Map.Entry<LOCAL_TERM_WEIGHTS, Double> entry : firstExpected.entrySet()) {
        final Double expected = entry.getValue();
        assertEquals(expected, KiePMMLTextIndex.evaluateRaw(true, false, TERM_0, TEXT_0, "\\s+", entry.getKey(), COUNT_HITS.ALL_HITS, distance), 0.0000001);
    }

    // --- scenarios 2 and 3: first flag false; only the separator differs, so both
    // share the same expected values (frequency 3.0, max frequency 3)
    final int secondMaxFrequency = 3;
    // the explicit (double) cast is kept for Sonar rule java:S2184
    final double secondAugmented = 0.5 * (1 + (frequency / (double) secondMaxFrequency));
    final Map<LOCAL_TERM_WEIGHTS, Double> secondExpected = new HashMap<>();
    secondExpected.put(TERM_FREQUENCY, frequency);
    secondExpected.put(BINARY, 1.0);
    secondExpected.put(LOGARITHMIC, logarithmic);
    secondExpected.put(AUGMENTED_NORMALIZED_TERM_FREQUENCY, secondAugmented);
    for (String separator : new String[]{"\\s+", "[\\s\\-]"}) {
        for (Map.Entry<LOCAL_TERM_WEIGHTS, Double> entry : secondExpected.entrySet()) {
            final Double expected = entry.getValue();
            assertEquals(expected, KiePMMLTextIndex.evaluateRaw(false, false, TERM_0, TEXT_0, separator, entry.getKey(), COUNT_HITS.ALL_HITS, distance), 0.0000001);
        }
    }
}
Also used : HashMap(java.util.HashMap) LevenshteinDistance(org.apache.commons.text.similarity.LevenshteinDistance) LOCAL_TERM_WEIGHTS(org.kie.pmml.api.enums.LOCAL_TERM_WEIGHTS) Test(org.junit.Test)

Example 2 with LOCAL_TERM_WEIGHTS

use of org.kie.pmml.api.enums.LOCAL_TERM_WEIGHTS in project drools by kiegroup.

the class KiePMMLTextIndexFactory method getTextIndexVariableDeclaration.

/**
 * Builds the statements needed to declare a text-index variable named {@code variableName}.
 *
 * <p>The returned block contains, in order: the statements generated for the
 * {@code TextIndex} expression, the statements generated for each
 * {@code TextIndexNormalization} (if any), and finally the statements of the cloned
 * template method body, whose variable declaration has been renamed and whose builder
 * chain has been populated from {@code textIndex}.
 *
 * @param variableName name given to the declared variable; also used as prefix for the
 *                     generated expression and normalization variable names
 * @param textIndex    the PMML {@code TextIndex} whose attributes populate the builder chain
 * @return a new {@code BlockStmt} holding all generated statements
 * @throws KiePMMLException if the template method has no body, the template variable
 *                          cannot be found in that body, or the variable has no initializer
 */
static BlockStmt getTextIndexVariableDeclaration(final String variableName, final TextIndex textIndex) {
    // Work on a clone so the shared template is never mutated.
    final MethodDeclaration methodDeclaration = TEXTINDEX_TEMPLATE.getMethodsByName(GETKIEPMMLTEXTINDEX).get(0).clone();
    final BlockStmt textIndexBody = methodDeclaration.getBody().orElseThrow(() -> new KiePMMLException(String.format(MISSING_BODY_TEMPLATE, methodDeclaration)));
    final VariableDeclarator variableDeclarator = getVariableDeclarator(textIndexBody, TEXTINDEX).orElseThrow(() -> new KiePMMLException(String.format(MISSING_VARIABLE_IN_BODY, TEXTINDEX, textIndexBody)));
    variableDeclarator.setName(variableName);
    final BlockStmt toReturn = new BlockStmt();
    // The nested expression is declared first, so the builder can refer to it by name.
    String expressionVariableName = String.format("%s_Expression", variableName);
    final BlockStmt expressionBlockStatement = getKiePMMLExpressionBlockStmt(expressionVariableName, textIndex.getExpression());
    expressionBlockStatement.getStatements().forEach(toReturn::addStatement);
    // Each normalization gets its own variable; their names become the "asList" arguments.
    int counter = 0;
    final NodeList<Expression> arguments = new NodeList<>();
    if (textIndex.hasTextIndexNormalizations()) {
        for (TextIndexNormalization textIndexNormalization : textIndex.getTextIndexNormalizations()) {
            String nestedVariableName = String.format(VARIABLE_NAME_TEMPLATE, variableName, counter);
            arguments.add(new NameExpr(nestedVariableName));
            BlockStmt toAdd = getTextIndexNormalizationVariableDeclaration(nestedVariableName, textIndexNormalization);
            toAdd.getStatements().forEach(toReturn::addStatement);
            counter++;
        }
    }
    // The declarator lives in textIndexBody, so a missing initializer is reported against
    // that block (consistent with the missing-variable check above) rather than against
    // the partially built return block.
    final MethodCallExpr initializer = variableDeclarator.getInitializer().orElseThrow(() -> new KiePMMLException(String.format(MISSING_VARIABLE_INITIALIZER_TEMPLATE, TEXTINDEX, textIndexBody))).asMethodCallExpr();
    final MethodCallExpr builder = getChainedMethodCallExprFrom("builder", initializer);
    final StringLiteralExpr nameExpr = new StringLiteralExpr(textIndex.getTextField().getValue());
    final NameExpr expressionExpr = new NameExpr(expressionVariableName);
    // Only arguments 0 and 2 of the template's builder(...) call are replaced here.
    builder.setArgument(0, nameExpr);
    builder.setArgument(2, expressionExpr);
    // Enum values are emitted as fully qualified names; absent values become a null literal.
    Expression localTermWeightsExpression;
    if (textIndex.getLocalTermWeights() != null) {
        final LOCAL_TERM_WEIGHTS localTermWeights = LOCAL_TERM_WEIGHTS.byName(textIndex.getLocalTermWeights().value());
        localTermWeightsExpression = new NameExpr(LOCAL_TERM_WEIGHTS.class.getName() + "." + localTermWeights.name());
    } else {
        localTermWeightsExpression = new NullLiteralExpr();
    }
    getChainedMethodCallExprFrom("withLocalTermWeights", initializer).setArgument(0, localTermWeightsExpression);
    getChainedMethodCallExprFrom("withIsCaseSensitive", initializer).setArgument(0, getExpressionForObject(textIndex.isCaseSensitive()));
    getChainedMethodCallExprFrom("withMaxLevenshteinDistance", initializer).setArgument(0, getExpressionForObject(textIndex.getMaxLevenshteinDistance()));
    Expression countHitsExpression;
    if (textIndex.getCountHits() != null) {
        final COUNT_HITS countHits = COUNT_HITS.byName(textIndex.getCountHits().value());
        countHitsExpression = new NameExpr(COUNT_HITS.class.getName() + "." + countHits.name());
    } else {
        countHitsExpression = new NullLiteralExpr();
    }
    getChainedMethodCallExprFrom("withCountHits", initializer).setArgument(0, countHitsExpression);
    Expression wordSeparatorCharacterREExpression;
    if (textIndex.getWordSeparatorCharacterRE() != null) {
        // Java-escape the regex before embedding it as a string literal in generated source.
        String wordSeparatorCharacterRE = StringEscapeUtils.escapeJava(textIndex.getWordSeparatorCharacterRE());
        wordSeparatorCharacterREExpression = new StringLiteralExpr(wordSeparatorCharacterRE);
    } else {
        wordSeparatorCharacterREExpression = new NullLiteralExpr();
    }
    getChainedMethodCallExprFrom("withWordSeparatorCharacterRE", initializer).setArgument(0, wordSeparatorCharacterREExpression);
    getChainedMethodCallExprFrom("withTokenize", initializer).setArgument(0, getExpressionForObject(textIndex.isTokenize()));
    getChainedMethodCallExprFrom("asList", initializer).setArguments(arguments);
    // The populated template statements go last, after everything they reference.
    textIndexBody.getStatements().forEach(toReturn::addStatement);
    return toReturn;
}
Also used : COUNT_HITS(org.kie.pmml.api.enums.COUNT_HITS) MethodDeclaration(com.github.javaparser.ast.body.MethodDeclaration) KiePMMLExpressionFactory.getKiePMMLExpressionBlockStmt(org.kie.pmml.compiler.commons.codegenfactories.KiePMMLExpressionFactory.getKiePMMLExpressionBlockStmt) BlockStmt(com.github.javaparser.ast.stmt.BlockStmt) NodeList(com.github.javaparser.ast.NodeList) NameExpr(com.github.javaparser.ast.expr.NameExpr) StringLiteralExpr(com.github.javaparser.ast.expr.StringLiteralExpr) LOCAL_TERM_WEIGHTS(org.kie.pmml.api.enums.LOCAL_TERM_WEIGHTS) VariableDeclarator(com.github.javaparser.ast.body.VariableDeclarator) CommonCodegenUtils.getVariableDeclarator(org.kie.pmml.compiler.commons.utils.CommonCodegenUtils.getVariableDeclarator) NullLiteralExpr(com.github.javaparser.ast.expr.NullLiteralExpr) TextIndexNormalization(org.dmg.pmml.TextIndexNormalization) Expression(com.github.javaparser.ast.expr.Expression) KiePMMLException(org.kie.pmml.api.exceptions.KiePMMLException) MethodCallExpr(com.github.javaparser.ast.expr.MethodCallExpr)

Example 3 with LOCAL_TERM_WEIGHTS

use of org.kie.pmml.api.enums.LOCAL_TERM_WEIGHTS in project drools by kiegroup.

the class KiePMMLTextIndexTest method evaluateTextIndex0Normalizations.

/**
 * Evaluates a {@code KiePMMLTextIndex} configured with the normalizations returned by
 * {@code getKiePMMLTextIndexNormalizations()} against {@code NOT_NORMALIZED_TEXT_0},
 * once per {@code LOCAL_TERM_WEIGHTS} value, and compares the result with the expected
 * weight computed from a frequency of 3.0 and a max frequency of 2.
 */
@Test
public void evaluateTextIndex0Normalizations() {
    // <Constant>brown fox</Constant>
    final KiePMMLConstant kiePMMLConstant = new KiePMMLConstant("NAME-1", Collections.emptyList(), TERM_0, null);
    final KiePMMLNameValue inputText = new KiePMMLNameValue(FIELD_NAME, NOT_NORMALIZED_TEXT_0);
    final List<KiePMMLNameValue> kiePMMLNameValues = Collections.singletonList(inputText);
    final ProcessingDTO processingDTO = getProcessingDTO(kiePMMLNameValues);

    final double frequency = 3.0;
    final int maxFrequency = 2;
    final Map<LOCAL_TERM_WEIGHTS, Double> expectedResults = new HashMap<>();
    expectedResults.put(TERM_FREQUENCY, frequency);
    expectedResults.put(BINARY, 1.0);
    expectedResults.put(LOGARITHMIC, Math.log10(1 + frequency));
    // the explicit (double) cast is kept for Sonar rule java:S2184
    expectedResults.put(AUGMENTED_NORMALIZED_TERM_FREQUENCY, 0.5 * (1 + (frequency / (double) maxFrequency)));

    for (Map.Entry<LOCAL_TERM_WEIGHTS, Double> entry : expectedResults.entrySet()) {
        final Double expected = entry.getValue();
        // A fresh text index (and fresh normalizations) is built for every weight type.
        final KiePMMLTextIndex kiePMMLTextIndex = KiePMMLTextIndex.builder(FIELD_NAME, Collections.emptyList(), kiePMMLConstant)
                .withMaxLevenshteinDistance(2)
                .withLocalTermWeights(entry.getKey())
                .withIsCaseSensitive(true)
                .withTextIndexNormalizations(getKiePMMLTextIndexNormalizations())
                .build();
        assertEquals(expected, kiePMMLTextIndex.evaluate(processingDTO));
    }
}
Also used : CommonTestingUtility.getProcessingDTO(org.kie.pmml.commons.CommonTestingUtility.getProcessingDTO) ProcessingDTO(org.kie.pmml.commons.model.ProcessingDTO) HashMap(java.util.HashMap) KiePMMLNameValue(org.kie.pmml.commons.model.tuples.KiePMMLNameValue) LOCAL_TERM_WEIGHTS(org.kie.pmml.api.enums.LOCAL_TERM_WEIGHTS) Test(org.junit.Test)

Example 4 with LOCAL_TERM_WEIGHTS

use of org.kie.pmml.api.enums.LOCAL_TERM_WEIGHTS in project drools by kiegroup.

the class KiePMMLTextIndexTest method evaluateRawTokenize.

/**
 * Checks {@code KiePMMLTextIndex.evaluateRaw} for every {@code LOCAL_TERM_WEIGHTS} value,
 * always passing {@code true} as the second flag (presumably "tokenize", judging by the
 * test name - confirm against the {@code evaluateRaw} signature).
 *
 * <p>Three scenarios are exercised:
 * <ol>
 *   <li>first flag {@code true}, separator {@code \s+}: frequency 3.0, max frequency 2;</li>
 *   <li>first flag {@code false}, separator {@code \s+}: frequency 3.0, max frequency 3;</li>
 *   <li>first flag {@code false}, separator {@code [\s\-]}: frequency 4.0, max frequency 3.</li>
 * </ol>
 */
@Test
public void evaluateRawTokenize() {
    final LevenshteinDistance distance = new LevenshteinDistance(2);

    // --- scenario 1: first flag true, whitespace separator
    final double baseFrequency = 3.0;
    final double baseLogarithmic = Math.log10(1 + baseFrequency);
    final int lowMaxFrequency = 2;
    final Map<LOCAL_TERM_WEIGHTS, Double> firstExpected = new HashMap<>();
    firstExpected.put(TERM_FREQUENCY, baseFrequency);
    firstExpected.put(BINARY, 1.0);
    firstExpected.put(LOGARITHMIC, baseLogarithmic);
    // the explicit (double) cast is kept for Sonar rule java:S2184
    firstExpected.put(AUGMENTED_NORMALIZED_TERM_FREQUENCY, 0.5 * (1 + (baseFrequency / (double) lowMaxFrequency)));
    for (Map.Entry<LOCAL_TERM_WEIGHTS, Double> entry : firstExpected.entrySet()) {
        final Double expected = entry.getValue();
        assertEquals(expected, KiePMMLTextIndex.evaluateRaw(true, true, TERM_0, TEXT_0, "\\s+", entry.getKey(), COUNT_HITS.ALL_HITS, distance), 0.0000001);
    }

    // --- scenario 2: first flag false, whitespace separator
    final int highMaxFrequency = 3;
    final Map<LOCAL_TERM_WEIGHTS, Double> secondExpected = new HashMap<>();
    secondExpected.put(TERM_FREQUENCY, baseFrequency);
    secondExpected.put(BINARY, 1.0);
    secondExpected.put(LOGARITHMIC, baseLogarithmic);
    // the explicit (double) cast is kept for Sonar rule java:S2184
    secondExpected.put(AUGMENTED_NORMALIZED_TERM_FREQUENCY, 0.5 * (1 + (baseFrequency / (double) highMaxFrequency)));
    for (Map.Entry<LOCAL_TERM_WEIGHTS, Double> entry : secondExpected.entrySet()) {
        final Double expected = entry.getValue();
        assertEquals(expected, KiePMMLTextIndex.evaluateRaw(false, true, TERM_0, TEXT_0, "\\s+", entry.getKey(), COUNT_HITS.ALL_HITS, distance), 0.0000001);
    }

    // --- scenario 3: first flag false, whitespace-or-hyphen separator; expected frequency is 4.0
    final double hyphenFrequency = 4.0;
    final double hyphenLogarithmic = Math.log10(1 + hyphenFrequency);
    final Map<LOCAL_TERM_WEIGHTS, Double> thirdExpected = new HashMap<>();
    thirdExpected.put(TERM_FREQUENCY, hyphenFrequency);
    thirdExpected.put(BINARY, 1.0);
    thirdExpected.put(LOGARITHMIC, hyphenLogarithmic);
    // the explicit (double) cast is kept for Sonar rule java:S2184
    thirdExpected.put(AUGMENTED_NORMALIZED_TERM_FREQUENCY, 0.5 * (1 + (hyphenFrequency / (double) highMaxFrequency)));
    for (Map.Entry<LOCAL_TERM_WEIGHTS, Double> entry : thirdExpected.entrySet()) {
        final Double expected = entry.getValue();
        assertEquals(expected, KiePMMLTextIndex.evaluateRaw(false, true, TERM_0, TEXT_0, "[\\s\\-]", entry.getKey(), COUNT_HITS.ALL_HITS, distance), 0.0000001);
    }
}
Also used : HashMap(java.util.HashMap) LevenshteinDistance(org.apache.commons.text.similarity.LevenshteinDistance) LOCAL_TERM_WEIGHTS(org.kie.pmml.api.enums.LOCAL_TERM_WEIGHTS) Test(org.junit.Test)

Example 5 with LOCAL_TERM_WEIGHTS

use of org.kie.pmml.api.enums.LOCAL_TERM_WEIGHTS in project drools by kiegroup.

the class KiePMMLTextIndexTest method evaluateNoTextIndex0Normalizations.

/**
 * Evaluates a {@code KiePMMLTextIndex} built without text-index normalizations (only a
 * {@code \s+} word separator) against {@code TEXT_0}, once per {@code LOCAL_TERM_WEIGHTS}
 * value, and compares the result with the expected weight computed from a frequency of
 * 3.0 and a max frequency of 2.
 */
@Test
public void evaluateNoTextIndex0Normalizations() {
    // <Constant>brown fox</Constant>
    final KiePMMLConstant kiePMMLConstant = new KiePMMLConstant("NAME-1", Collections.emptyList(), TERM_0, null);
    final KiePMMLNameValue inputText = new KiePMMLNameValue(FIELD_NAME, TEXT_0);
    final List<KiePMMLNameValue> kiePMMLNameValues = Collections.singletonList(inputText);
    final ProcessingDTO processingDTO = getProcessingDTO(kiePMMLNameValues);

    final double frequency = 3.0;
    final int maxFrequency = 2;
    final Map<LOCAL_TERM_WEIGHTS, Double> expectedResults = new HashMap<>();
    expectedResults.put(TERM_FREQUENCY, frequency);
    expectedResults.put(BINARY, 1.0);
    expectedResults.put(LOGARITHMIC, Math.log10(1 + frequency));
    // the explicit (double) cast is kept for Sonar rule java:S2184
    expectedResults.put(AUGMENTED_NORMALIZED_TERM_FREQUENCY, 0.5 * (1 + (frequency / (double) maxFrequency)));

    for (Map.Entry<LOCAL_TERM_WEIGHTS, Double> entry : expectedResults.entrySet()) {
        final Double expected = entry.getValue();
        // A fresh text index is built for every weight type.
        final KiePMMLTextIndex kiePMMLTextIndex = KiePMMLTextIndex.builder(FIELD_NAME, Collections.emptyList(), kiePMMLConstant)
                .withMaxLevenshteinDistance(2)
                .withLocalTermWeights(entry.getKey())
                .withIsCaseSensitive(true)
                .withWordSeparatorCharacterRE("\\s+")
                .build();
        assertEquals(expected, kiePMMLTextIndex.evaluate(processingDTO));
    }
}
Also used : CommonTestingUtility.getProcessingDTO(org.kie.pmml.commons.CommonTestingUtility.getProcessingDTO) ProcessingDTO(org.kie.pmml.commons.model.ProcessingDTO) HashMap(java.util.HashMap) KiePMMLNameValue(org.kie.pmml.commons.model.tuples.KiePMMLNameValue) LOCAL_TERM_WEIGHTS(org.kie.pmml.api.enums.LOCAL_TERM_WEIGHTS) Test(org.junit.Test)

Aggregations

LOCAL_TERM_WEIGHTS (org.kie.pmml.api.enums.LOCAL_TERM_WEIGHTS): 6; HashMap (java.util.HashMap): 4; Test (org.junit.Test): 4; LevenshteinDistance (org.apache.commons.text.similarity.LevenshteinDistance): 2; COUNT_HITS (org.kie.pmml.api.enums.COUNT_HITS): 2; CommonTestingUtility.getProcessingDTO (org.kie.pmml.commons.CommonTestingUtility.getProcessingDTO): 2; ProcessingDTO (org.kie.pmml.commons.model.ProcessingDTO): 2; KiePMMLNameValue (org.kie.pmml.commons.model.tuples.KiePMMLNameValue): 2; NodeList (com.github.javaparser.ast.NodeList): 1; MethodDeclaration (com.github.javaparser.ast.body.MethodDeclaration): 1; VariableDeclarator (com.github.javaparser.ast.body.VariableDeclarator): 1; Expression (com.github.javaparser.ast.expr.Expression): 1; MethodCallExpr (com.github.javaparser.ast.expr.MethodCallExpr): 1; NameExpr (com.github.javaparser.ast.expr.NameExpr): 1; NullLiteralExpr (com.github.javaparser.ast.expr.NullLiteralExpr): 1; StringLiteralExpr (com.github.javaparser.ast.expr.StringLiteralExpr): 1; BlockStmt (com.github.javaparser.ast.stmt.BlockStmt): 1; TextIndexNormalization (org.dmg.pmml.TextIndexNormalization): 1; KiePMMLException (org.kie.pmml.api.exceptions.KiePMMLException): 1; KiePMMLExpressionFactory.getKiePMMLExpressionBlockStmt (org.kie.pmml.compiler.commons.codegenfactories.KiePMMLExpressionFactory.getKiePMMLExpressionBlockStmt): 1