use of org.kie.kogito.explainability.model.Output in project kogito-apps by kiegroup.
the class TestUtils method getLinearModel.
/**
 * Test model producing, for every input, one numeric output named {@code "linear-sum"} whose value is
 * the sum of each feature's numeric value multiplied by the weight at the same position.
 * The computation is submitted through {@code supplyAsync}.
 *
 * @param weights per-feature multipliers, indexed by feature position; an input with more features
 *        than weights fails with an index error inside the submitted task
 * @return A {@link PredictionProvider} model
 */
public static PredictionProvider getLinearModel(double[] weights) {
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> results = new LinkedList<>();
        for (PredictionInput input : inputs) {
            double weightedSum = 0;
            int position = 0;
            for (Feature feature : input.getFeatures()) {
                weightedSum += feature.getValue().asNumber() * weights[position];
                position++;
            }
            // One output per input; the trailing 1d is passed unchanged to the Output constructor.
            Output linearSum = new Output("linear-sum", Type.NUMBER, new Value(weightedSum), 1d);
            results.add(new PredictionOutput(List.of(linearSum)));
        }
        return results;
    });
}
use of org.kie.kogito.explainability.model.Output in project kogito-apps by kiegroup.
the class TestUtils method getEvenSumModel.
/**
 * Test model producing, for every input, one boolean output telling whether the sum of the numeric
 * feature values (leaving out the feature at {@code skipFeatureIndex}) is even after truncation to
 * {@code int}. The output is named {@code "sum-even-but"} followed by the skipped index.
 * The computation is submitted through {@code supplyAsync}.
 *
 * @param skipFeatureIndex position of the feature that does not contribute to the sum
 * @return A {@link PredictionProvider} model
 */
public static PredictionProvider getEvenSumModel(int skipFeatureIndex) {
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> results = new LinkedList<>();
        for (PredictionInput input : inputs) {
            List<Feature> features = input.getFeatures();
            double total = 0;
            for (int position = 0; position < features.size(); position++) {
                if (position == skipFeatureIndex) {
                    continue;
                }
                total += features.get(position).getValue().asNumber();
            }
            // Parity is evaluated on the sum truncated to int, not on the raw double.
            boolean even = ((int) total) % 2 == 0;
            Output parity = new Output("sum-even-but" + skipFeatureIndex, Type.BOOLEAN, new Value(even), 1d);
            results.add(new PredictionOutput(List.of(parity)));
        }
        return results;
    });
}
use of org.kie.kogito.explainability.model.Output in project kogito-apps by kiegroup.
the class TestUtils method getDummyTextClassifier.
/**
 * Test model producing, for every input, one boolean output named {@code "spam"}. The output is
 * {@code true} when any feature's string value, split on single spaces, contains a token that
 * exactly equals one of the fixed black-listed tokens. The computation is submitted through
 * {@code supplyAsync}.
 *
 * @return A {@link PredictionProvider} model
 */
public static PredictionProvider getDummyTextClassifier() {
    List<String> bannedTokens = Arrays.asList("money", "$", "£", "bitcoin");
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> results = new LinkedList<>();
        for (PredictionInput input : inputs) {
            // anyMatch short-circuits, so features after the first hit are never inspected.
            boolean flagged = input.getFeatures().stream()
                    .map(feature -> feature.getValue().asString().split(" "))
                    .anyMatch(tokens -> Arrays.stream(tokens).anyMatch(bannedTokens::contains));
            Output verdict = new Output("spam", Type.BOOLEAN, new Value(flagged), 1d);
            results.add(new PredictionOutput(List.of(verdict)));
        }
        return results;
    });
}
use of org.kie.kogito.explainability.model.Output in project kogito-apps by kiegroup.
the class TestUtils method getFeaturePassModel.
/**
 * Test model producing, for every input, one output that carries the type and value of the feature
 * found at {@code featureIndex}. The output is named {@code "feature-"} followed by that index.
 * The computation is submitted through {@code supplyAsync}.
 *
 * @param featureIndex position of the input feature to copy into the output
 * @return A {@link PredictionProvider} model
 */
public static PredictionProvider getFeaturePassModel(int featureIndex) {
    String outputName = "feature-" + featureIndex;
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> results = new LinkedList<>();
        for (PredictionInput input : inputs) {
            Feature selected = input.getFeatures().get(featureIndex);
            Output echoed = new Output(outputName, selected.getType(), selected.getValue(), 1d);
            results.add(new PredictionOutput(List.of(echoed)));
        }
        return results;
    });
}
use of org.kie.kogito.explainability.model.Output in project kogito-apps by kiegroup.
the class TestUtils method getFeatureSkipModel.
/**
 * Test model that mirrors each input's features as outputs (same name, type and value), leaving
 * out the single feature at the given position. The computation is submitted through
 * {@code supplyAsync}.
 *
 * @param featureIndex Index of the input feature to omit from output
 * @return A {@link PredictionProvider} model
 */
public static PredictionProvider getFeatureSkipModel(int featureIndex) {
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> results = new LinkedList<>();
        for (PredictionInput input : inputs) {
            List<Output> retained = new ArrayList<>();
            int position = 0;
            for (Feature feature : input.getFeatures()) {
                boolean omitted = position == featureIndex;
                position++;
                if (omitted) {
                    continue;
                }
                retained.add(new Output(feature.getName(), feature.getType(), feature.getValue(), 1.0));
            }
            results.add(new PredictionOutput(retained));
        }
        return results;
    });
}
Aggregations