Use of org.kie.pmml.api.enums.LOCAL_TERM_WEIGHTS in project drools by kiegroup.
The class KiePMMLTextIndexTest, method evaluateRawNoTokenize.
/**
 * Checks {@code KiePMMLTextIndex.evaluateRaw} without tokenization for every {@link LOCAL_TERM_WEIGHTS}
 * value, in three scenarios:
 * <ol>
 *   <li>case-sensitive, separator {@code "\\s+"}: expected maximum frequency 2;</li>
 *   <li>case-insensitive, separator {@code "\\s+"}: expected maximum frequency 3;</li>
 *   <li>case-insensitive, separator {@code "[\\s\\-]"}: same expectations as scenario 2.</li>
 * </ol>
 * The term frequency is 3 in all three scenarios.
 */
@Test
public void evaluateRawNoTokenize() {
    final LevenshteinDistance distance = new LevenshteinDistance(2);
    final double termFrequency = 3.0;
    final double logarithmicWeight = Math.log10(1 + termFrequency);
    final int caseSensitiveMaxFrequency = 2;
    final int caseInsensitiveMaxFrequency = 3;

    // Scenario 1: case-sensitive, whitespace separator.
    // The (double) cast avoids integer division (Sonar rule java:S2184).
    final Map<LOCAL_TERM_WEIGHTS, Double> caseSensitiveExpected = new HashMap<>();
    caseSensitiveExpected.put(TERM_FREQUENCY, termFrequency);
    caseSensitiveExpected.put(BINARY, 1.0);
    caseSensitiveExpected.put(LOGARITHMIC, logarithmicWeight);
    caseSensitiveExpected.put(AUGMENTED_NORMALIZED_TERM_FREQUENCY,
            0.5 * (1 + (termFrequency / (double) caseSensitiveMaxFrequency)));
    for (Map.Entry<LOCAL_TERM_WEIGHTS, Double> entry : caseSensitiveExpected.entrySet()) {
        assertEquals(entry.getValue(),
                KiePMMLTextIndex.evaluateRaw(true, false, TERM_0, TEXT_0, "\\s+", entry.getKey(),
                        COUNT_HITS.ALL_HITS, distance),
                0.0000001);
    }

    // Scenarios 2 and 3 share the same expected values.
    // The (double) cast avoids integer division (Sonar rule java:S2184).
    final Map<LOCAL_TERM_WEIGHTS, Double> caseInsensitiveExpected = new HashMap<>();
    caseInsensitiveExpected.put(TERM_FREQUENCY, termFrequency);
    caseInsensitiveExpected.put(BINARY, 1.0);
    caseInsensitiveExpected.put(LOGARITHMIC, logarithmicWeight);
    caseInsensitiveExpected.put(AUGMENTED_NORMALIZED_TERM_FREQUENCY,
            0.5 * (1 + (termFrequency / (double) caseInsensitiveMaxFrequency)));

    // Scenario 2: case-insensitive, whitespace separator.
    for (Map.Entry<LOCAL_TERM_WEIGHTS, Double> entry : caseInsensitiveExpected.entrySet()) {
        assertEquals(entry.getValue(),
                KiePMMLTextIndex.evaluateRaw(false, false, TERM_0, TEXT_0, "\\s+", entry.getKey(),
                        COUNT_HITS.ALL_HITS, distance),
                0.0000001);
    }

    // Scenario 3: case-insensitive, whitespace-or-hyphen separator.
    for (Map.Entry<LOCAL_TERM_WEIGHTS, Double> entry : caseInsensitiveExpected.entrySet()) {
        assertEquals(entry.getValue(),
                KiePMMLTextIndex.evaluateRaw(false, false, TERM_0, TEXT_0, "[\\s\\-]", entry.getKey(),
                        COUNT_HITS.ALL_HITS, distance),
                0.0000001);
    }
}
Use of org.kie.pmml.api.enums.LOCAL_TERM_WEIGHTS in project drools by kiegroup.
The class KiePMMLTextIndexFactory, method getTextIndexVariableDeclaration.
/**
 * Generates the JavaParser statements that declare a {@code KiePMMLTextIndex} variable matching the given
 * {@link TextIndex}.
 * <p>
 * The returned block contains, in this order:
 * <ol>
 *   <li>the statements declaring the index's expression, in a variable named {@code <variableName>_Expression};</li>
 *   <li>the statements declaring each {@link TextIndexNormalization}, one variable per normalization;</li>
 *   <li>the body of a clone of the {@code GETKIEPMMLTEXTINDEX} template method. The variable is renamed to
 *       {@code variableName} and the builder chain is filled in from {@code textIndex}.</li>
 * </ol>
 *
 * @param variableName name to give the generated {@code KiePMMLTextIndex} variable
 * @param textIndex    source PMML {@code TextIndex} to translate
 * @return a new {@link BlockStmt} holding all generated statements
 * @throws KiePMMLException in any of these cases:
 *         <ul>
 *           <li>the cloned template method has no body;</li>
 *           <li>the body does not declare the {@code TEXTINDEX} variable;</li>
 *           <li>that variable has no initializer.</li>
 *         </ul>
 */
static BlockStmt getTextIndexVariableDeclaration(final String variableName, final TextIndex textIndex) {
    final MethodDeclaration methodDeclaration =
            TEXTINDEX_TEMPLATE.getMethodsByName(GETKIEPMMLTEXTINDEX).get(0).clone();
    final BlockStmt textIndexBody = methodDeclaration.getBody()
            .orElseThrow(() -> new KiePMMLException(String.format(MISSING_BODY_TEMPLATE, methodDeclaration)));
    final VariableDeclarator variableDeclarator = getVariableDeclarator(textIndexBody, TEXTINDEX)
            .orElseThrow(() -> new KiePMMLException(String.format(MISSING_VARIABLE_IN_BODY, TEXTINDEX,
                    textIndexBody)));
    variableDeclarator.setName(variableName);
    final BlockStmt toReturn = new BlockStmt();

    // Declare the expression first; the builder below references it by name (argument 2).
    final String expressionVariableName = String.format("%s_Expression", variableName);
    final BlockStmt expressionBlockStatement =
            getKiePMMLExpressionBlockStmt(expressionVariableName, textIndex.getExpression());
    expressionBlockStatement.getStatements().forEach(toReturn::addStatement);

    // Declare one variable per normalization; their names become the arguments of the "asList" call.
    int counter = 0;
    final NodeList<Expression> arguments = new NodeList<>();
    if (textIndex.hasTextIndexNormalizations()) {
        for (TextIndexNormalization textIndexNormalization : textIndex.getTextIndexNormalizations()) {
            String nestedVariableName = String.format(VARIABLE_NAME_TEMPLATE, variableName, counter);
            arguments.add(new NameExpr(nestedVariableName));
            BlockStmt toAdd = getTextIndexNormalizationVariableDeclaration(nestedVariableName, textIndexNormalization);
            toAdd.getStatements().forEach(toReturn::addStatement);
            counter++;
        }
    }

    // Report the template body that is missing the initializer.
    // toReturn only holds the generated expression and normalization statements at this point.
    final MethodCallExpr initializer = variableDeclarator.getInitializer()
            .orElseThrow(() -> new KiePMMLException(String.format(MISSING_VARIABLE_INITIALIZER_TEMPLATE, TEXTINDEX,
                    textIndexBody)))
            .asMethodCallExpr();

    // Only builder arguments 0 (field name) and 2 (expression) are set here.
    // Argument 1 keeps whatever the template provides.
    final MethodCallExpr builder = getChainedMethodCallExprFrom("builder", initializer);
    final StringLiteralExpr nameExpr = new StringLiteralExpr(textIndex.getTextField().getValue());
    final NameExpr expressionExpr = new NameExpr(expressionVariableName);
    builder.setArgument(0, nameExpr);
    builder.setArgument(2, expressionExpr);

    // Optional enum attributes become a fully-qualified constant reference, or a null literal when absent.
    Expression localTermWeightsExpression;
    if (textIndex.getLocalTermWeights() != null) {
        final LOCAL_TERM_WEIGHTS localTermWeights =
                LOCAL_TERM_WEIGHTS.byName(textIndex.getLocalTermWeights().value());
        localTermWeightsExpression = new NameExpr(LOCAL_TERM_WEIGHTS.class.getName() + "." + localTermWeights.name());
    } else {
        localTermWeightsExpression = new NullLiteralExpr();
    }
    getChainedMethodCallExprFrom("withLocalTermWeights", initializer).setArgument(0, localTermWeightsExpression);
    getChainedMethodCallExprFrom("withIsCaseSensitive", initializer)
            .setArgument(0, getExpressionForObject(textIndex.isCaseSensitive()));
    getChainedMethodCallExprFrom("withMaxLevenshteinDistance", initializer)
            .setArgument(0, getExpressionForObject(textIndex.getMaxLevenshteinDistance()));

    Expression countHitsExpression;
    if (textIndex.getCountHits() != null) {
        final COUNT_HITS countHits = COUNT_HITS.byName(textIndex.getCountHits().value());
        countHitsExpression = new NameExpr(COUNT_HITS.class.getName() + "." + countHits.name());
    } else {
        countHitsExpression = new NullLiteralExpr();
    }
    getChainedMethodCallExprFrom("withCountHits", initializer).setArgument(0, countHitsExpression);

    // Escape the regex so its backslashes survive being emitted as a Java string literal in generated source.
    Expression wordSeparatorCharacterREExpression;
    if (textIndex.getWordSeparatorCharacterRE() != null) {
        String wordSeparatorCharacterRE = StringEscapeUtils.escapeJava(textIndex.getWordSeparatorCharacterRE());
        wordSeparatorCharacterREExpression = new StringLiteralExpr(wordSeparatorCharacterRE);
    } else {
        wordSeparatorCharacterREExpression = new NullLiteralExpr();
    }
    getChainedMethodCallExprFrom("withWordSeparatorCharacterRE", initializer)
            .setArgument(0, wordSeparatorCharacterREExpression);
    getChainedMethodCallExprFrom("withTokenize", initializer)
            .setArgument(0, getExpressionForObject(textIndex.isTokenize()));
    getChainedMethodCallExprFrom("asList", initializer).setArguments(arguments);

    // The filled-in template statements go last, after the variables they reference.
    textIndexBody.getStatements().forEach(toReturn::addStatement);
    return toReturn;
}
Use of org.kie.pmml.api.enums.LOCAL_TERM_WEIGHTS in project drools by kiegroup.
The class KiePMMLTextIndexTest, method evaluateTextIndex0Normalizations.
/**
 * Checks {@code KiePMMLTextIndex.evaluate} on {@code NOT_NORMALIZED_TEXT_0} when text-index normalizations
 * are configured. Each {@link LOCAL_TERM_WEIGHTS} value is tested with a term frequency of 3 and an
 * expected maximum frequency of 2.
 */
@Test
public void evaluateTextIndex0Normalizations() {
    // <Constant>brown fox</Constant>
    final KiePMMLConstant termConstant = new KiePMMLConstant("NAME-1", Collections.emptyList(), TERM_0, null);
    final List<KiePMMLNameValue> nameValues =
            Collections.singletonList(new KiePMMLNameValue(FIELD_NAME, NOT_NORMALIZED_TEXT_0));
    final ProcessingDTO dto = getProcessingDTO(nameValues);

    final double termFrequency = 3.0;
    final int maxFrequency = 2;
    final Map<LOCAL_TERM_WEIGHTS, Double> weightsToExpected = new HashMap<>();
    weightsToExpected.put(TERM_FREQUENCY, termFrequency);
    weightsToExpected.put(BINARY, 1.0);
    weightsToExpected.put(LOGARITHMIC, Math.log10(1 + termFrequency));
    // The (double) cast avoids integer division (Sonar rule java:S2184).
    weightsToExpected.put(AUGMENTED_NORMALIZED_TERM_FREQUENCY,
            0.5 * (1 + (termFrequency / (double) maxFrequency)));

    for (Map.Entry<LOCAL_TERM_WEIGHTS, Double> entry : weightsToExpected.entrySet()) {
        final KiePMMLTextIndex textIndex = KiePMMLTextIndex.builder(FIELD_NAME, Collections.emptyList(), termConstant)
                .withMaxLevenshteinDistance(2)
                .withLocalTermWeights(entry.getKey())
                .withIsCaseSensitive(true)
                .withTextIndexNormalizations(getKiePMMLTextIndexNormalizations())
                .build();
        assertEquals(entry.getValue(), textIndex.evaluate(dto));
    }
}
Use of org.kie.pmml.api.enums.LOCAL_TERM_WEIGHTS in project drools by kiegroup.
The class KiePMMLTextIndexTest, method evaluateRawTokenize.
@Test
/**
 * Checks {@code KiePMMLTextIndex.evaluateRaw} with tokenization enabled for every {@link LOCAL_TERM_WEIGHTS}
 * value, in three scenarios:
 * <ol>
 *   <li>case-sensitive, separator {@code "\\s+"}: term frequency 3, expected maximum frequency 2;</li>
 *   <li>case-insensitive, separator {@code "\\s+"}: term frequency 3, expected maximum frequency 3;</li>
 *   <li>case-insensitive, separator {@code "[\\s\\-]"}: term frequency 4, expected maximum frequency 3.</li>
 * </ol>
 */
@Test
public void evaluateRawTokenize() {
    final LevenshteinDistance distance = new LevenshteinDistance(2);
    final double whitespaceFrequency = 3.0;
    final double hyphenAwareFrequency = 4.0;
    final int caseSensitiveMaxFrequency = 2;
    final int caseInsensitiveMaxFrequency = 3;

    // Scenario 1: case-sensitive, whitespace separator.
    // The (double) casts in this method avoid integer division (Sonar rule java:S2184).
    final Map<LOCAL_TERM_WEIGHTS, Double> scenarioOne = new HashMap<>();
    scenarioOne.put(TERM_FREQUENCY, whitespaceFrequency);
    scenarioOne.put(BINARY, 1.0);
    scenarioOne.put(LOGARITHMIC, Math.log10(1 + whitespaceFrequency));
    scenarioOne.put(AUGMENTED_NORMALIZED_TERM_FREQUENCY,
            0.5 * (1 + (whitespaceFrequency / (double) caseSensitiveMaxFrequency)));
    for (Map.Entry<LOCAL_TERM_WEIGHTS, Double> entry : scenarioOne.entrySet()) {
        assertEquals(entry.getValue(),
                KiePMMLTextIndex.evaluateRaw(true, true, TERM_0, TEXT_0, "\\s+", entry.getKey(),
                        COUNT_HITS.ALL_HITS, distance),
                0.0000001);
    }

    // Scenario 2: case-insensitive, whitespace separator.
    final Map<LOCAL_TERM_WEIGHTS, Double> scenarioTwo = new HashMap<>();
    scenarioTwo.put(TERM_FREQUENCY, whitespaceFrequency);
    scenarioTwo.put(BINARY, 1.0);
    scenarioTwo.put(LOGARITHMIC, Math.log10(1 + whitespaceFrequency));
    scenarioTwo.put(AUGMENTED_NORMALIZED_TERM_FREQUENCY,
            0.5 * (1 + (whitespaceFrequency / (double) caseInsensitiveMaxFrequency)));
    for (Map.Entry<LOCAL_TERM_WEIGHTS, Double> entry : scenarioTwo.entrySet()) {
        assertEquals(entry.getValue(),
                KiePMMLTextIndex.evaluateRaw(false, true, TERM_0, TEXT_0, "\\s+", entry.getKey(),
                        COUNT_HITS.ALL_HITS, distance),
                0.0000001);
    }

    // Scenario 3: case-insensitive, whitespace-or-hyphen separator.
    final Map<LOCAL_TERM_WEIGHTS, Double> scenarioThree = new HashMap<>();
    scenarioThree.put(TERM_FREQUENCY, hyphenAwareFrequency);
    scenarioThree.put(BINARY, 1.0);
    scenarioThree.put(LOGARITHMIC, Math.log10(1 + hyphenAwareFrequency));
    scenarioThree.put(AUGMENTED_NORMALIZED_TERM_FREQUENCY,
            0.5 * (1 + (hyphenAwareFrequency / (double) caseInsensitiveMaxFrequency)));
    for (Map.Entry<LOCAL_TERM_WEIGHTS, Double> entry : scenarioThree.entrySet()) {
        assertEquals(entry.getValue(),
                KiePMMLTextIndex.evaluateRaw(false, true, TERM_0, TEXT_0, "[\\s\\-]", entry.getKey(),
                        COUNT_HITS.ALL_HITS, distance),
                0.0000001);
    }
}
Use of org.kie.pmml.api.enums.LOCAL_TERM_WEIGHTS in project drools by kiegroup.
The class KiePMMLTextIndexTest, method evaluateNoTextIndex0Normalizations.
/**
 * Checks {@code KiePMMLTextIndex.evaluate} on {@code TEXT_0} with no text-index normalizations and the
 * separator {@code "\\s+"}. Each {@link LOCAL_TERM_WEIGHTS} value is tested with a term frequency of 3 and an
 * expected maximum frequency of 2.
 */
@Test
public void evaluateNoTextIndex0Normalizations() {
    // <Constant>brown fox</Constant>
    final KiePMMLConstant termConstant = new KiePMMLConstant("NAME-1", Collections.emptyList(), TERM_0, null);
    final List<KiePMMLNameValue> nameValues =
            Collections.singletonList(new KiePMMLNameValue(FIELD_NAME, TEXT_0));
    final ProcessingDTO dto = getProcessingDTO(nameValues);

    final double termFrequency = 3.0;
    final int maxFrequency = 2;
    final Map<LOCAL_TERM_WEIGHTS, Double> weightsToExpected = new HashMap<>();
    weightsToExpected.put(TERM_FREQUENCY, termFrequency);
    weightsToExpected.put(BINARY, 1.0);
    weightsToExpected.put(LOGARITHMIC, Math.log10(1 + termFrequency));
    // The (double) cast avoids integer division (Sonar rule java:S2184).
    weightsToExpected.put(AUGMENTED_NORMALIZED_TERM_FREQUENCY,
            0.5 * (1 + (termFrequency / (double) maxFrequency)));

    for (Map.Entry<LOCAL_TERM_WEIGHTS, Double> entry : weightsToExpected.entrySet()) {
        final KiePMMLTextIndex textIndex = KiePMMLTextIndex.builder(FIELD_NAME, Collections.emptyList(), termConstant)
                .withMaxLevenshteinDistance(2)
                .withLocalTermWeights(entry.getKey())
                .withIsCaseSensitive(true)
                .withWordSeparatorCharacterRE("\\s+")
                .build();
        assertEquals(entry.getValue(), textIndex.evaluate(dto));
    }
}
Aggregations