Search in sources:

Example 6 with PartialDependenceGraph

Use of org.kie.kogito.explainability.model.PartialDependenceGraph in the kogito-apps project by kiegroup.

The class PartialDependencePlotExplainerTest, method testTextClassifier.

/**
 * Checks that the PDP explainer produces at least one graph from predictions made by the dummy
 * text classifier on a handful of short free-text inputs.
 *
 * @param seed parameterized seed. NOTE(review): nothing in this test consumes it. The original
 *             seeded a java.util.Random that was never used, so it was removed. Wire the seed into
 *             the model or explainer if seed-dependent behavior is intended.
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testTextClassifier(int seed) throws Exception {
    PartialDependencePlotExplainer partialDependencePlotExplainer = new PartialDependencePlotExplainer();
    PredictionProvider model = TestUtils.getDummyTextClassifier();
    List<String> texts = List.of("we want your money", "please reply quickly", "you are the lucky winner", "huge donation for you!", "bitcoin for you");
    // presize to the real number of inputs (was hard-coded to 3 while 5 texts are added)
    Collection<Prediction> predictions = new ArrayList<>(texts.size());
    for (String text : texts) {
        // each prediction has a single full-text feature named "text"
        List<Feature> features = new ArrayList<>();
        features.add(FeatureFactory.newFulltextFeature("text", text));
        PredictionInput predictionInput = new PredictionInput(features);
        PredictionOutput predictionOutput = model.predictAsync(List.of(predictionInput)).get().get(0);
        predictions.add(new SimplePrediction(predictionInput, predictionOutput));
    }
    List<PartialDependenceGraph> pdps = partialDependencePlotExplainer.explainFromPredictions(model, predictions);
    assertThat(pdps).isNotEmpty();
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) ArrayList(java.util.ArrayList) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) Random(java.util.Random) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) PartialDependenceGraph(org.kie.kogito.explainability.model.PartialDependenceGraph) ValueSource(org.junit.jupiter.params.provider.ValueSource) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Example 7 with PartialDependenceGraph

Use of org.kie.kogito.explainability.model.PartialDependenceGraph in the kogito-apps project by kiegroup.

The class PartialDependencePlotExplainer, method getPartialDependenceGraph.

/**
 * Builds the partial dependence graph of one model output against one feature.
 * <p>
 * Each value of the feature under analysis is processed in turn:
 * <ol>
 *   <li>the training data is turned into inputs via {@code prepareInputs};</li>
 *   <li>those inputs are evaluated in one batch via {@code getOutputs};</li>
 *   <li>the output at {@code outputIndex} of every prediction is folded into that value's counts.</li>
 * </ol>
 * The counts are finally collapsed into the graph's y-values.
 *
 * @param model           the model to query
 * @param trainingData    inputs used as the base for every batch of predictions
 * @param xsValues        sampled values of the feature under analysis, passed to the graph as-is
 * @param featureXSvalues the feature under analysis, one copy per sampled value
 * @param outputIndex     position of the plotted output within each prediction's outputs
 * @return the partial dependence graph for the feature and output
 * @throws IllegalArgumentException if no output was collected (e.g. {@code featureXSvalues} is empty)
 */
private PartialDependenceGraph getPartialDependenceGraph(PredictionProvider model, List<PredictionInput> trainingData, List<Value> xsValues, List<Feature> featureXSvalues, int outputIndex) throws InterruptedException, ExecutionException, TimeoutException {
    // one count map per feature value; each will hold the output values observed for it
    List<Map<Value, Long>> valueCounts = new ArrayList<>(featureXSvalues.size());
    // the graph carries a value-less copy of the feature under analysis
    Feature feature = featureXSvalues.isEmpty() ? null : FeatureFactory.copyOf(featureXSvalues.get(0), new Value(null));
    Output outputDecision = null;
    int valueIndex = 0;
    for (Feature xsFeature : featureXSvalues) {
        // one batched prediction request per value of the feature under analysis
        for (PredictionOutput predictionOutput : getOutputs(model, prepareInputs(xsFeature, trainingData))) {
            Output observed = predictionOutput.getOutputs().get(outputIndex);
            if (outputDecision == null) {
                // name and type are taken from the first output seen; the value is dropped
                outputDecision = new Output(observed.getName(), observed.getType());
            }
            updateValueCounts(valueCounts, valueIndex, observed);
        }
        valueIndex++;
    }
    if (outputDecision == null) {
        throw new IllegalArgumentException("cannot produce PDP for null decision");
    }
    List<Value> yValues = collapseMarginalImpacts(valueCounts, outputDecision.getType());
    return new PartialDependenceGraph(feature, outputDecision, xsValues, yValues);
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) Feature(org.kie.kogito.explainability.model.Feature) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) Value(org.kie.kogito.explainability.model.Value) HashMap(java.util.HashMap) Map(java.util.Map) PartialDependenceGraph(org.kie.kogito.explainability.model.PartialDependenceGraph)

Example 8 with PartialDependenceGraph

Use of org.kie.kogito.explainability.model.PartialDependenceGraph in the kogito-apps project by kiegroup.

The class PartialDependencePlotExplainer, method explainFromDataDistribution.

/**
 * Produces one partial dependence graph for every (feature, output) pair of a data distribution.
 * <p>
 * A training sample of {@code config.getSeriesLength()} inputs is drawn once from the whole
 * distribution. For each feature, a further sample of that length is drawn, ordered and
 * de-duplicated. It is then used as the x-values of one graph per output index in
 * {@code [0, outputSize)}.
 *
 * @param model            the model to explain
 * @param outputSize       number of outputs to build a graph for, per feature
 * @param dataDistribution distribution used to sample training data and feature values
 * @return the graphs, grouped by feature in distribution order, then by output index
 */
private List<PartialDependenceGraph> explainFromDataDistribution(PredictionProvider model, int outputSize, DataDistribution dataDistribution) throws InterruptedException, ExecutionException, TimeoutException {
    // nanoTime is monotonic, unlike currentTimeMillis, so it is the right clock for elapsed time
    long start = System.nanoTime();
    List<PartialDependenceGraph> pdps = new ArrayList<>();
    List<FeatureDistribution> featureDistributions = dataDistribution.asFeatureDistributions();
    // fetch entire data distributions for all features
    List<PredictionInput> trainingData = dataDistribution.sample(config.getSeriesLength());
    // Order by numeric value. Ties, including all NaN values (which Double.compare puts last), are
    // ordered alphanumerically. This matches the former "sort by string, then stable sort by number"
    // pair, but builds the comparator once instead of once per comparison.
    Comparator<Value> valueOrder = Comparator.<Value>comparingDouble(Value::asNumber).thenComparing(Value::asString);
    // create a PDP for each feature
    for (FeatureDistribution featureDistribution : featureDistributions) {
        // generate (further) samples for the feature under analysis
        // TBD: maybe just reuse trainingData
        List<Value> xsValues = featureDistribution.sample(config.getSeriesLength()).stream()
                .sorted(valueOrder)
                .distinct()
                .collect(Collectors.toList());
        // transform sampled Values into Features
        List<Feature> featureXSvalues = xsValues.stream()
                .map(v -> FeatureFactory.copyOf(featureDistribution.getFeature(), v))
                .collect(Collectors.toList());
        // create a PDP for each feature and each output
        for (int outputIndex = 0; outputIndex < outputSize; outputIndex++) {
            pdps.add(getPartialDependenceGraph(model, trainingData, xsValues, featureXSvalues, outputIndex));
        }
    }
    long elapsedMillis = (System.nanoTime() - start) / 1_000_000L;
    LOGGER.debug("explanation time: {}ms", elapsedMillis);
    return pdps;
}
Also used : FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) LoggerFactory(org.slf4j.LoggerFactory) TimeoutException(java.util.concurrent.TimeoutException) HashMap(java.util.HashMap) Value(org.kie.kogito.explainability.model.Value) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) ArrayList(java.util.ArrayList) PartialDependenceGraph(org.kie.kogito.explainability.model.PartialDependenceGraph) Map(java.util.Map) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) GlobalExplainer(org.kie.kogito.explainability.global.GlobalExplainer) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Logger(org.slf4j.Logger) Collection(java.util.Collection) Collectors(java.util.stream.Collectors) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) PredictionProviderMetadata(org.kie.kogito.explainability.model.PredictionProviderMetadata) ExecutionException(java.util.concurrent.ExecutionException) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Output(org.kie.kogito.explainability.model.Output) Comparator(java.util.Comparator) Config(org.kie.kogito.explainability.Config) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) Feature(org.kie.kogito.explainability.model.Feature) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) Value(org.kie.kogito.explainability.model.Value) PartialDependenceGraph(org.kie.kogito.explainability.model.PartialDependenceGraph)

Example 9 with PartialDependenceGraph

Use of org.kie.kogito.explainability.model.PartialDependenceGraph in the kogito-apps project by kiegroup.

The class TrafficViolationDmnPDPExplainerTest, method testTrafficViolationDMNExplanation.

/**
 * Explains the "Traffic Violation" DMN model with partial dependence plots. The explanation is
 * built from random inputs and checks that the explainer returns exactly 8 graphs.
 */
@Test
void testTrafficViolationDMNExplanation() throws ExecutionException, InterruptedException, TimeoutException, java.io.IOException {
    DMNRuntime dmnRuntime;
    // Close the model stream once the runtime is built, and fail with a clear message if the resource
    // is missing. The resource is decoded as UTF-8 rather than with the platform default charset
    // (assumes the .dmn file is UTF-8, as XML usually is; confirm).
    try (InputStreamReader dmnReader = new InputStreamReader(
            java.util.Objects.requireNonNull(getClass().getResourceAsStream("/dmn/TrafficViolation.dmn"), "/dmn/TrafficViolation.dmn not found on classpath"),
            java.nio.charset.StandardCharsets.UTF_8)) {
        dmnRuntime = DMNKogito.createGenericDMNRuntime(dmnReader);
    }
    assertEquals(1, dmnRuntime.getModels().size());
    final String TRAFFIC_VIOLATION_NS = "https://github.com/kiegroup/drools/kie-dmn/_A4BCA8B8-CF08-433F-93B2-A2598F19ECFF";
    final String TRAFFIC_VIOLATION_NAME = "Traffic Violation";
    DecisionModel decisionModel = new DmnDecisionModel(dmnRuntime, TRAFFIC_VIOLATION_NS, TRAFFIC_VIOLATION_NAME);
    PredictionProvider model = new DecisionModelWrapper(decisionModel);
    List<PredictionInput> inputs = DmnTestUtils.randomTrafficViolationInputs();
    List<PredictionOutput> predictionOutputs = model.predictAsync(inputs).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    // pair each input with the output the model produced for it
    List<Prediction> predictions = new ArrayList<>();
    for (int i = 0; i < predictionOutputs.size(); i++) {
        predictions.add(new SimplePrediction(inputs.get(i), predictionOutputs.get(i)));
    }
    PartialDependencePlotExplainer partialDependencePlotExplainer = new PartialDependencePlotExplainer();
    List<PartialDependenceGraph> pdps = partialDependencePlotExplainer.explainFromPredictions(model, predictions);
    AssertionsForClassTypes.assertThat(pdps).isNotNull();
    Assertions.assertThat(pdps).hasSize(8);
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) InputStreamReader(java.io.InputStreamReader) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) Prediction(org.kie.kogito.explainability.model.Prediction) ArrayList(java.util.ArrayList) DecisionModel(org.kie.kogito.decision.DecisionModel) DmnDecisionModel(org.kie.kogito.dmn.DmnDecisionModel) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) DMNRuntime(org.kie.dmn.api.core.DMNRuntime) PartialDependencePlotExplainer(org.kie.kogito.explainability.global.pdp.PartialDependencePlotExplainer) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) DmnDecisionModel(org.kie.kogito.dmn.DmnDecisionModel) PartialDependenceGraph(org.kie.kogito.explainability.model.PartialDependenceGraph) Test(org.junit.jupiter.api.Test)

Example 10 with PartialDependenceGraph

Use of org.kie.kogito.explainability.model.PartialDependenceGraph in the kogito-apps project by kiegroup.

The class PrequalificationDmnPDPExplainerTest, method testPrequalificationDMNExplanation.

/**
 * Explains the "Prequalification" DMN model with partial dependence plots. The explanation is
 * built from random inputs and checks that the explainer returns exactly 25 graphs.
 */
@Test
void testPrequalificationDMNExplanation() throws ExecutionException, InterruptedException, TimeoutException, java.io.IOException {
    DMNRuntime dmnRuntime;
    // Close the model stream once the runtime is built, and fail with a clear message if the resource
    // is missing. The resource is decoded as UTF-8 rather than with the platform default charset
    // (assumes the .dmn file is UTF-8, as XML usually is; confirm).
    try (InputStreamReader dmnReader = new InputStreamReader(
            java.util.Objects.requireNonNull(getClass().getResourceAsStream("/dmn/Prequalification-1.dmn"), "/dmn/Prequalification-1.dmn not found on classpath"),
            java.nio.charset.StandardCharsets.UTF_8)) {
        dmnRuntime = DMNKogito.createGenericDMNRuntime(dmnReader);
    }
    assertEquals(1, dmnRuntime.getModels().size());
    final String NS = "http://www.trisotech.com/definitions/_f31e1f8e-d4ce-4a3a-ac3b-747efa6b3401";
    final String NAME = "Prequalification";
    DecisionModel decisionModel = new DmnDecisionModel(dmnRuntime, NS, NAME);
    PredictionProvider model = new DecisionModelWrapper(decisionModel);
    List<PredictionInput> inputs = DmnTestUtils.randomPrequalificationInputs();
    List<PredictionOutput> predictionOutputs = model.predictAsync(inputs).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    // pair each input with the output the model produced for it
    List<Prediction> predictions = new ArrayList<>();
    for (int i = 0; i < predictionOutputs.size(); i++) {
        predictions.add(new SimplePrediction(inputs.get(i), predictionOutputs.get(i)));
    }
    PartialDependencePlotExplainer partialDependencePlotExplainer = new PartialDependencePlotExplainer();
    List<PartialDependenceGraph> pdps = partialDependencePlotExplainer.explainFromPredictions(model, predictions);
    AssertionsForClassTypes.assertThat(pdps).isNotNull();
    Assertions.assertThat(pdps).hasSize(25);
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) InputStreamReader(java.io.InputStreamReader) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) Prediction(org.kie.kogito.explainability.model.Prediction) ArrayList(java.util.ArrayList) DecisionModel(org.kie.kogito.decision.DecisionModel) DmnDecisionModel(org.kie.kogito.dmn.DmnDecisionModel) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) DMNRuntime(org.kie.dmn.api.core.DMNRuntime) PartialDependencePlotExplainer(org.kie.kogito.explainability.global.pdp.PartialDependencePlotExplainer) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) DmnDecisionModel(org.kie.kogito.dmn.DmnDecisionModel) PartialDependenceGraph(org.kie.kogito.explainability.model.PartialDependenceGraph) Test(org.junit.jupiter.api.Test)

Aggregations

PartialDependenceGraph (org.kie.kogito.explainability.model.PartialDependenceGraph)10 ArrayList (java.util.ArrayList)9 PredictionInput (org.kie.kogito.explainability.model.PredictionInput)8 PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput)8 PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider)8 Prediction (org.kie.kogito.explainability.model.Prediction)7 Test (org.junit.jupiter.api.Test)6 SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction)6 PartialDependencePlotExplainer (org.kie.kogito.explainability.global.pdp.PartialDependencePlotExplainer)5 Feature (org.kie.kogito.explainability.model.Feature)5 InputStreamReader (java.io.InputStreamReader)4 DMNRuntime (org.kie.dmn.api.core.DMNRuntime)4 DecisionModel (org.kie.kogito.decision.DecisionModel)4 DmnDecisionModel (org.kie.kogito.dmn.DmnDecisionModel)4 Output (org.kie.kogito.explainability.model.Output)4 Value (org.kie.kogito.explainability.model.Value)4 ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)3 HashMap (java.util.HashMap)2 List (java.util.List)2 Map (java.util.Map)2