Use of org.kie.kogito.explainability.model.PartialDependenceGraph in the kogito-apps project by kiegroup.
Class PartialDependencePlotExplainerTest, method testPdpNumericClassifier.
/**
 * Checks that partial dependence graphs computed from metadata are well formed and reflect which
 * features the sum-skip model actually uses.
 *
 * @param seed seed for the random generator used to build the metadata, so every run is reproducible
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testPdpNumericClassifier(int seed) throws Exception {
    Random rng = new Random(seed);
    PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(0);
    PartialDependencePlotExplainer explainer = new PartialDependencePlotExplainer();
    List<PartialDependenceGraph> graphs = explainer.explainFromMetadata(sumSkipModel, getMetadata(rng));
    assertNotNull(graphs);

    // Every graph needs a feature and x/y series, with exactly one y value per x value.
    for (PartialDependenceGraph graph : graphs) {
        assertNotNull(graph.getFeature());
        assertNotNull(graph.getX());
        assertNotNull(graph.getY());
        assertEquals(graph.getX().size(), graph.getY().size());
        assertGraph(graph);
    }

    // The model always skips the first feature. Varying it cannot change the prediction, so its PDP
    // is flat and holds a single distinct y value.
    long distinctSkippedY = graphs.get(0).getY().stream().distinct().count();
    assertEquals(1, distinctSkippedY);

    // The remaining two features do feed the prediction, so their PDPs take more than one y value.
    long distinctSecondY = graphs.get(1).getY().stream().distinct().count();
    assertThat(distinctSecondY).isGreaterThan(1);
    long distinctThirdY = graphs.get(2).getY().stream().distinct().count();
    assertThat(distinctThirdY).isGreaterThan(1);
}
Use of org.kie.kogito.explainability.model.PartialDependenceGraph in the kogito-apps project by kiegroup.
Class FraudScoringDmnPDPExplainerTest, method testFraudScoringDMNExplanation.
/**
 * Runs the fraud-scoring DMN model on random inputs and checks that the partial dependence plot
 * explainer produces the expected number of graphs from the resulting predictions.
 */
@Test
void testFraudScoringDMNExplanation() throws ExecutionException, InterruptedException, TimeoutException {
    // Decode the DMN resource explicitly as UTF-8 so the test does not depend on the platform default
    // charset. A missing resource fails fast with a descriptive message instead of a bare NPE.
    DMNRuntime dmnRuntime = DMNKogito.createGenericDMNRuntime(new InputStreamReader(
            java.util.Objects.requireNonNull(getClass().getResourceAsStream("/dmn/fraud.dmn"),
                    "test resource /dmn/fraud.dmn not found"),
            java.nio.charset.StandardCharsets.UTF_8));
    assertEquals(1, dmnRuntime.getModels().size());

    final String fraudNamespace = "http://www.redhat.com/dmn/definitions/_81556584-7d78-4f8c-9d5f-b3cddb9b5c73";
    final String fraudModelName = "fraud-scoring";
    DecisionModel decisionModel = new DmnDecisionModel(dmnRuntime, fraudNamespace, fraudModelName);
    PredictionProvider model = new DecisionModelWrapper(decisionModel);

    List<PredictionInput> inputs = DmnTestUtils.randomFraudScoringInputs();
    List<PredictionOutput> predictionOutputs = model.predictAsync(inputs)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    // Outputs are paired with inputs by position.
    List<Prediction> predictions = new ArrayList<>(predictionOutputs.size());
    for (int i = 0; i < predictionOutputs.size(); i++) {
        predictions.add(new SimplePrediction(inputs.get(i), predictionOutputs.get(i)));
    }

    PartialDependencePlotExplainer partialDependencePlotExplainer = new PartialDependencePlotExplainer();
    List<PartialDependenceGraph> pdps = partialDependencePlotExplainer.explainFromPredictions(model, predictions);
    // NOTE(review): 32 is presumably one graph per (feature, output) pair of this model — confirm.
    Assertions.assertThat(pdps).isNotNull().hasSize(32);
}
Use of org.kie.kogito.explainability.model.PartialDependenceGraph in the kogito-apps project by kiegroup.
Class LoanEligibilityDmnPDPExplainerTest, method testLoanEligibilityDMNExplanation.
/**
 * Runs the LoanEligibility DMN model on random inputs and checks that the partial dependence plot
 * explainer produces the expected number of graphs from the resulting predictions.
 */
@Test
void testLoanEligibilityDMNExplanation() throws ExecutionException, InterruptedException, TimeoutException {
    // Decode the DMN resource explicitly as UTF-8 so the test does not depend on the platform default
    // charset. A missing resource fails fast with a descriptive message instead of a bare NPE.
    DMNRuntime dmnRuntime = DMNKogito.createGenericDMNRuntime(new InputStreamReader(
            java.util.Objects.requireNonNull(getClass().getResourceAsStream("/dmn/LoanEligibility.dmn"),
                    "test resource /dmn/LoanEligibility.dmn not found"),
            java.nio.charset.StandardCharsets.UTF_8));
    assertEquals(1, dmnRuntime.getModels().size());

    // Namespace and name of the LoanEligibility model. The previous FRAUD_* names were a copy-paste leftover.
    final String loanNamespace = "https://github.com/kiegroup/kogito-examples/dmn-quarkus-listener-example";
    final String loanModelName = "LoanEligibility";
    DecisionModel decisionModel = new DmnDecisionModel(dmnRuntime, loanNamespace, loanModelName);
    PredictionProvider model = new DecisionModelWrapper(decisionModel);

    List<PredictionInput> inputs = DmnTestUtils.randomLoanEligibilityInputs();
    List<PredictionOutput> predictionOutputs = model.predictAsync(inputs)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    // Outputs are paired with inputs by position.
    List<Prediction> predictions = new ArrayList<>(predictionOutputs.size());
    for (int i = 0; i < predictionOutputs.size(); i++) {
        predictions.add(new SimplePrediction(inputs.get(i), predictionOutputs.get(i)));
    }

    PartialDependencePlotExplainer partialDependencePlotExplainer = new PartialDependencePlotExplainer();
    List<PartialDependenceGraph> pdps = partialDependencePlotExplainer.explainFromPredictions(model, predictions);
    // NOTE(review): 20 is presumably one graph per (feature, output) pair of this model — confirm.
    Assertions.assertThat(pdps).isNotNull().hasSize(20);
}
Use of org.kie.kogito.explainability.model.PartialDependenceGraph in the kogito-apps project by kiegroup.
Class OpenNLPPDPExplainerTest, method testOpenNLPLangDetect.
/**
 * Wraps an OpenNLP language detector as a {@code PredictionProvider} and checks that the partial
 * dependence plot explainer produces at least one graph for a handful of short text inputs.
 */
@Test
void testOpenNLPLangDetect() throws Exception {
    PartialDependencePlotExplainer partialDependencePlotExplainer = new PartialDependencePlotExplainer();

    // Close the model stream once the model is built instead of leaking it for the rest of the test.
    // NOTE(review): assumes the LanguageDetectorModel constructor fully reads the stream — confirm.
    LanguageDetectorModel languageDetectorModel;
    try (InputStream is = getClass().getResourceAsStream("/opennlp/langdetect-183.bin")) {
        languageDetectorModel = new LanguageDetectorModel(
                java.util.Objects.requireNonNull(is, "test resource /opennlp/langdetect-183.bin not found"));
    }
    LanguageDetector languageDetector = new LanguageDetectorME(languageDetectorModel);

    // For each input, join all feature values with single spaces and emit the detected language
    // (with its confidence) as a single "lang" text output.
    PredictionProvider model = inputs -> CompletableFuture.supplyAsync(() -> {
        List<PredictionOutput> results = new ArrayList<>();
        for (PredictionInput predictionInput : inputs) {
            StringBuilder builder = new StringBuilder();
            for (Feature f : predictionInput.getFeatures()) {
                if (builder.length() > 0) {
                    builder.append(' ');
                }
                builder.append(f.getValue().asString());
            }
            Language language = languageDetector.predictLanguage(builder.toString());
            PredictionOutput predictionOutput = new PredictionOutput(List.of(new Output("lang", Type.TEXT, new Value(language.getLang()), language.getConfidence())));
            results.add(predictionOutput);
        }
        return results;
    });

    List<String> texts = List.of("we want your money", "please reply quickly", "you are the lucky winner", "italiani, spaghetti pizza mandolino", "guten tag", "allez les bleus", "daje roma");
    // Bound every wait so a stuck prediction fails the test instead of hanging the build.
    final long predictionTimeoutSeconds = 60;
    List<Prediction> predictions = new ArrayList<>(texts.size());
    for (String text : texts) {
        List<Feature> features = new ArrayList<>();
        features.add(FeatureFactory.newFulltextFeature("text", text));
        PredictionInput predictionInput = new PredictionInput(features);
        PredictionOutput predictionOutput = model.predictAsync(List.of(predictionInput))
                .get(predictionTimeoutSeconds, java.util.concurrent.TimeUnit.SECONDS)
                .get(0);
        predictions.add(new SimplePrediction(predictionInput, predictionOutput));
    }

    List<PartialDependenceGraph> pdps = partialDependencePlotExplainer.explainFromPredictions(model, predictions);
    assertThat(pdps).isNotEmpty();
}
Use of org.kie.kogito.explainability.model.PartialDependenceGraph in the kogito-apps project by kiegroup.
Class DataUtilsTest, method toCSV.
/**
 * Checks that a small partial dependence graph can be exported to a CSV file under
 * {@code target/} without throwing.
 */
@Test
void toCSV() {
    // Mock only the names the graph exposes for its feature and output.
    Feature feature = mock(Feature.class);
    when(feature.getName()).thenReturn("feature-1");
    Output output = mock(Output.class);
    when(output.getName()).thenReturn("decision-1");

    // Three x points and their corresponding y values, kept in mutable ArrayLists.
    List<Value> x = new ArrayList<>(List.of(new Value(1), new Value(2), new Value(3)));
    List<Value> y = new ArrayList<>(List.of(new Value(4), new Value(5), new Value(4)));
    PartialDependenceGraph graph = new PartialDependenceGraph(feature, output, x, y);

    assertDoesNotThrow(() -> DataUtils.toCSV(graph, Paths.get("target/test-pdp.csv")));
}
Aggregations