Search in sources:

Example 1 with PartialDependenceGraph

use of org.kie.kogito.explainability.model.PartialDependenceGraph in project kogito-apps by kiegroup.

From the class PartialDependencePlotExplainerTest, method testPdpNumericClassifier.

/**
 * Explains the sum-skip test model from seeded random metadata and validates the resulting
 * partial dependence graphs.
 *
 * <p>Every graph must have a feature, non-null X and Y series of equal length, and pass
 * {@code assertGraph}. The first feature is skipped by the model, so its graph must have a single
 * distinct Y value. The second and third graphs must show more than one distinct Y value.
 *
 * @param seed seed for the {@link Random} passed to {@code getMetadata}
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testPdpNumericClassifier(int seed) throws Exception {
    Random rng = new Random(seed);
    PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(0);
    PartialDependencePlotExplainer explainer = new PartialDependencePlotExplainer();
    List<PartialDependenceGraph> graphs = explainer.explainFromMetadata(sumSkipModel, getMetadata(rng));
    assertNotNull(graphs);

    for (PartialDependenceGraph graph : graphs) {
        assertNotNull(graph.getFeature());
        assertNotNull(graph.getX());
        assertNotNull(graph.getY());
        assertEquals(graph.getX().size(), graph.getY().size());
        assertGraph(graph);
    }

    // The model ignores the first feature, so varying it cannot change the prediction:
    // its PDP must be flat (exactly one distinct Y value).
    long skippedFeatureDistinctY = graphs.get(0).getY().stream().distinct().count();
    assertEquals(1, skippedFeatureDistinctY);

    // The remaining two features do contribute, so their Y values must vary.
    long secondFeatureDistinctY = graphs.get(1).getY().stream().distinct().count();
    long thirdFeatureDistinctY = graphs.get(2).getY().stream().distinct().count();
    assertThat(secondFeatureDistinctY).isGreaterThan(1);
    assertThat(thirdFeatureDistinctY).isGreaterThan(1);
}
Also used : Random(java.util.Random) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) PartialDependenceGraph(org.kie.kogito.explainability.model.PartialDependenceGraph) ValueSource(org.junit.jupiter.params.provider.ValueSource) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Example 2 with PartialDependenceGraph

use of org.kie.kogito.explainability.model.PartialDependenceGraph in project kogito-apps by kiegroup.

From the class FraudScoringDmnPDPExplainerTest, method testFraudScoringDMNExplanation.

/**
 * Tests PDP explanations for the fraud-scoring DMN model.
 *
 * <p>Builds the model from {@code /dmn/fraud.dmn} and scores the inputs returned by
 * {@code DmnTestUtils.randomFraudScoringInputs()}. It then runs the partial dependence plot
 * explainer over the resulting predictions and asserts that exactly 32 graphs are produced.
 */
@Test
void testFraudScoringDMNExplanation() throws ExecutionException, InterruptedException, TimeoutException, java.io.IOException {
    DMNRuntime dmnRuntime;
    // The reader is closed once the runtime has been built, instead of leaking the classpath stream.
    // The model is decoded explicitly as UTF-8 rather than with the platform default charset.
    // A missing resource fails with a clear message instead of a bare NPE from InputStreamReader.
    // NOTE(review): this assumes createGenericDMNRuntime consumes the reader before returning.
    try (InputStreamReader dmnReader = new InputStreamReader(
            java.util.Objects.requireNonNull(getClass().getResourceAsStream("/dmn/fraud.dmn"),
                    "test resource /dmn/fraud.dmn not found"),
            java.nio.charset.StandardCharsets.UTF_8)) {
        dmnRuntime = DMNKogito.createGenericDMNRuntime(dmnReader);
    }
    assertEquals(1, dmnRuntime.getModels().size());
    final String FRAUD_NS = "http://www.redhat.com/dmn/definitions/_81556584-7d78-4f8c-9d5f-b3cddb9b5c73";
    final String FRAUD_NAME = "fraud-scoring";
    DecisionModel decisionModel = new DmnDecisionModel(dmnRuntime, FRAUD_NS, FRAUD_NAME);
    PredictionProvider model = new DecisionModelWrapper(decisionModel);
    List<PredictionInput> inputs = DmnTestUtils.randomFraudScoringInputs();
    List<PredictionOutput> predictionOutputs = model.predictAsync(inputs).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    // Pair each input with the output produced for it (outputs are in input order).
    List<Prediction> predictions = new ArrayList<>();
    for (int i = 0; i < predictionOutputs.size(); i++) {
        predictions.add(new SimplePrediction(inputs.get(i), predictionOutputs.get(i)));
    }
    PartialDependencePlotExplainer partialDependencePlotExplainer = new PartialDependencePlotExplainer();
    List<PartialDependenceGraph> pdps = partialDependencePlotExplainer.explainFromPredictions(model, predictions);
    assertThat(pdps).isNotNull();
    Assertions.assertThat(pdps).hasSize(32);
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) InputStreamReader(java.io.InputStreamReader) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) Prediction(org.kie.kogito.explainability.model.Prediction) ArrayList(java.util.ArrayList) DecisionModel(org.kie.kogito.decision.DecisionModel) DmnDecisionModel(org.kie.kogito.dmn.DmnDecisionModel) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) DMNRuntime(org.kie.dmn.api.core.DMNRuntime) PartialDependencePlotExplainer(org.kie.kogito.explainability.global.pdp.PartialDependencePlotExplainer) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) DmnDecisionModel(org.kie.kogito.dmn.DmnDecisionModel) PartialDependenceGraph(org.kie.kogito.explainability.model.PartialDependenceGraph) Test(org.junit.jupiter.api.Test)

Example 3 with PartialDependenceGraph

use of org.kie.kogito.explainability.model.PartialDependenceGraph in project kogito-apps by kiegroup.

From the class LoanEligibilityDmnPDPExplainerTest, method testLoanEligibilityDMNExplanation.

/**
 * Tests PDP explanations for the LoanEligibility DMN model.
 *
 * <p>Builds the model from {@code /dmn/LoanEligibility.dmn} and scores the inputs returned by
 * {@code DmnTestUtils.randomLoanEligibilityInputs()}. It then runs the partial dependence plot
 * explainer over the resulting predictions and asserts that exactly 20 graphs are produced.
 */
@Test
void testLoanEligibilityDMNExplanation() throws ExecutionException, InterruptedException, TimeoutException, java.io.IOException {
    DMNRuntime dmnRuntime;
    // The reader is closed once the runtime has been built, instead of leaking the classpath stream.
    // The model is decoded explicitly as UTF-8 rather than with the platform default charset.
    // A missing resource fails with a clear message instead of a bare NPE from InputStreamReader.
    // NOTE(review): this assumes createGenericDMNRuntime consumes the reader before returning.
    try (InputStreamReader dmnReader = new InputStreamReader(
            java.util.Objects.requireNonNull(getClass().getResourceAsStream("/dmn/LoanEligibility.dmn"),
                    "test resource /dmn/LoanEligibility.dmn not found"),
            java.nio.charset.StandardCharsets.UTF_8)) {
        dmnRuntime = DMNKogito.createGenericDMNRuntime(dmnReader);
    }
    assertEquals(1, dmnRuntime.getModels().size());
    // Namespace and name identifying the LoanEligibility model inside the DMN runtime.
    final String LOAN_NS = "https://github.com/kiegroup/kogito-examples/dmn-quarkus-listener-example";
    final String LOAN_NAME = "LoanEligibility";
    DecisionModel decisionModel = new DmnDecisionModel(dmnRuntime, LOAN_NS, LOAN_NAME);
    PredictionProvider model = new DecisionModelWrapper(decisionModel);
    List<PredictionInput> inputs = DmnTestUtils.randomLoanEligibilityInputs();
    List<PredictionOutput> predictionOutputs = model.predictAsync(inputs).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    // Pair each input with the output produced for it (outputs are in input order).
    List<Prediction> predictions = new ArrayList<>();
    for (int i = 0; i < predictionOutputs.size(); i++) {
        predictions.add(new SimplePrediction(inputs.get(i), predictionOutputs.get(i)));
    }
    PartialDependencePlotExplainer partialDependencePlotExplainer = new PartialDependencePlotExplainer();
    List<PartialDependenceGraph> pdps = partialDependencePlotExplainer.explainFromPredictions(model, predictions);
    assertThat(pdps).isNotNull();
    Assertions.assertThat(pdps).hasSize(20);
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) InputStreamReader(java.io.InputStreamReader) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) Prediction(org.kie.kogito.explainability.model.Prediction) ArrayList(java.util.ArrayList) DecisionModel(org.kie.kogito.decision.DecisionModel) DmnDecisionModel(org.kie.kogito.dmn.DmnDecisionModel) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) DMNRuntime(org.kie.dmn.api.core.DMNRuntime) PartialDependencePlotExplainer(org.kie.kogito.explainability.global.pdp.PartialDependencePlotExplainer) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) DmnDecisionModel(org.kie.kogito.dmn.DmnDecisionModel) PartialDependenceGraph(org.kie.kogito.explainability.model.PartialDependenceGraph) Test(org.junit.jupiter.api.Test)

Example 4 with PartialDependenceGraph

use of org.kie.kogito.explainability.model.PartialDependenceGraph in project kogito-apps by kiegroup.

From the class OpenNLPPDPExplainerTest, method testOpenNLPLangDetect.

/**
 * Tests PDP explanations for an OpenNLP language detector wrapped as a {@code PredictionProvider}.
 *
 * <p>The detector model is loaded from {@code /opennlp/langdetect-183.bin}. The provider joins each
 * input's feature values with single spaces and emits one {@code "lang"} TEXT output holding the
 * detected language code and its confidence. The test scores a few short texts and asserts that
 * the explainer produces at least one graph from those predictions.
 */
@Test
void testOpenNLPLangDetect() throws Exception {
    PartialDependencePlotExplainer partialDependencePlotExplainer = new PartialDependencePlotExplainer();
    LanguageDetectorModel languageDetectorModel;
    // The model stream is closed as soon as the model has been read, instead of leaking it.
    // A missing resource fails with a clear message.
    try (InputStream modelStream = java.util.Objects.requireNonNull(
            getClass().getResourceAsStream("/opennlp/langdetect-183.bin"),
            "test resource /opennlp/langdetect-183.bin not found")) {
        languageDetectorModel = new LanguageDetectorModel(modelStream);
    }
    LanguageDetector languageDetector = new LanguageDetectorME(languageDetectorModel);
    PredictionProvider model = inputs -> CompletableFuture.supplyAsync(() -> {
        List<PredictionOutput> results = new ArrayList<>();
        for (PredictionInput predictionInput : inputs) {
            // Rebuild the text to classify by space-joining the string form of every feature.
            StringBuilder builder = new StringBuilder();
            for (Feature f : predictionInput.getFeatures()) {
                if (builder.length() > 0) {
                    builder.append(' ');
                }
                builder.append(f.getValue().asString());
            }
            Language language = languageDetector.predictLanguage(builder.toString());
            PredictionOutput predictionOutput = new PredictionOutput(List.of(new Output("lang", Type.TEXT, new Value(language.getLang()), language.getConfidence())));
            results.add(predictionOutput);
        }
        return results;
    });
    List<String> texts = List.of("we want your money", "please reply quickly", "you are the lucky winner", "italiani, spaghetti pizza mandolino", "guten tag", "allez les bleus", "daje roma");
    List<Prediction> predictions = new ArrayList<>();
    for (String text : texts) {
        // Each text becomes a single full-text feature and is scored on its own.
        List<Feature> features = new ArrayList<>();
        features.add(FeatureFactory.newFulltextFeature("text", text));
        PredictionInput predictionInput = new PredictionInput(features);
        PredictionOutput predictionOutput = model.predictAsync(List.of(predictionInput)).get().get(0);
        predictions.add(new SimplePrediction(predictionInput, predictionOutput));
    }
    List<PartialDependenceGraph> pdps = partialDependencePlotExplainer.explainFromPredictions(model, predictions);
    assertThat(pdps).isNotEmpty();
}
Also used : LanguageDetectorME(opennlp.tools.langdetect.LanguageDetectorME) FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) LanguageDetectorModel(opennlp.tools.langdetect.LanguageDetectorModel) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) Assertions.assertThat(org.assertj.core.api.Assertions.assertThat) CompletableFuture(java.util.concurrent.CompletableFuture) Value(org.kie.kogito.explainability.model.Value) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ArrayList(java.util.ArrayList) Test(org.junit.jupiter.api.Test) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) PartialDependencePlotExplainer(org.kie.kogito.explainability.global.pdp.PartialDependencePlotExplainer) Language(opennlp.tools.langdetect.Language) LanguageDetector(opennlp.tools.langdetect.LanguageDetector) PartialDependenceGraph(org.kie.kogito.explainability.model.PartialDependenceGraph) Output(org.kie.kogito.explainability.model.Output) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) InputStream(java.io.InputStream) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) InputStream(java.io.InputStream) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) Prediction(org.kie.kogito.explainability.model.Prediction) LanguageDetectorME(opennlp.tools.langdetect.LanguageDetectorME) ArrayList(java.util.ArrayList) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) LanguageDetector(opennlp.tools.langdetect.LanguageDetector) PartialDependencePlotExplainer(org.kie.kogito.explainability.global.pdp.PartialDependencePlotExplainer) 
Language(opennlp.tools.langdetect.Language) LanguageDetectorModel(opennlp.tools.langdetect.LanguageDetectorModel) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Value(org.kie.kogito.explainability.model.Value) PartialDependenceGraph(org.kie.kogito.explainability.model.PartialDependenceGraph) Test(org.junit.jupiter.api.Test)

Example 5 with PartialDependenceGraph

use of org.kie.kogito.explainability.model.PartialDependenceGraph in project kogito-apps by kiegroup.

From the class DataUtilsTest, method toCSV.

/**
 * Checks that {@code DataUtils.toCSV} can write a small hand-built partial dependence graph to
 * {@code target/test-pdp.csv} without throwing.
 *
 * <p>The graph uses a mocked feature named {@code "feature-1"} and a mocked output named
 * {@code "decision-1"}. It has three points: x = 1, 2, 3 paired with y = 4, 5, 4.
 */
@Test
void toCSV() {
    Output decision = mock(Output.class);
    when(decision.getName()).thenReturn("decision-1");
    Feature inputFeature = mock(Feature.class);
    when(inputFeature.getName()).thenReturn("feature-1");

    List<Value> xValues = new ArrayList<>(List.of(new Value(1), new Value(2), new Value(3)));
    List<Value> yValues = new ArrayList<>(List.of(new Value(4), new Value(5), new Value(4)));

    PartialDependenceGraph graph = new PartialDependenceGraph(inputFeature, decision, xValues, yValues);
    assertDoesNotThrow(() -> DataUtils.toCSV(graph, Paths.get("target/test-pdp.csv")));
}
Also used : Output(org.kie.kogito.explainability.model.Output) Value(org.kie.kogito.explainability.model.Value) ArrayList(java.util.ArrayList) Feature(org.kie.kogito.explainability.model.Feature) PartialDependenceGraph(org.kie.kogito.explainability.model.PartialDependenceGraph) Test(org.junit.jupiter.api.Test) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Aggregations

PartialDependenceGraph (org.kie.kogito.explainability.model.PartialDependenceGraph)10 ArrayList (java.util.ArrayList)9 PredictionInput (org.kie.kogito.explainability.model.PredictionInput)8 PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput)8 PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider)8 Prediction (org.kie.kogito.explainability.model.Prediction)7 Test (org.junit.jupiter.api.Test)6 SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction)6 PartialDependencePlotExplainer (org.kie.kogito.explainability.global.pdp.PartialDependencePlotExplainer)5 Feature (org.kie.kogito.explainability.model.Feature)5 InputStreamReader (java.io.InputStreamReader)4 DMNRuntime (org.kie.dmn.api.core.DMNRuntime)4 DecisionModel (org.kie.kogito.decision.DecisionModel)4 DmnDecisionModel (org.kie.kogito.dmn.DmnDecisionModel)4 Output (org.kie.kogito.explainability.model.Output)4 Value (org.kie.kogito.explainability.model.Value)4 ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)3 HashMap (java.util.HashMap)2 List (java.util.List)2 Map (java.util.Map)2