
Example 26 with DataDistribution

Use of org.kie.kogito.explainability.model.DataDistribution in project kogito-apps by kiegroup.

From the class LimeConfigTest, method testDataDistribution.

@Test
void testDataDistribution() {
    DataDistribution dd = mock(DataDistribution.class);
    LimeConfig config = new LimeConfig().withDataDistribution(dd);
    assertThat(config.getDataDistribution()).isEqualTo(dd);
}
Also used: DataDistribution(org.kie.kogito.explainability.model.DataDistribution) Test(org.junit.jupiter.api.Test)
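The test above only mocks DataDistribution. To give a concrete picture of what such an object offers (an emptiness check and sample(n)), here is a minimal, hypothetical stand-in in plain Java. It keeps recorded rows as double[] (an assumption, in place of PredictionInput) and draws n rows uniformly at random with replacement. The class name InMemoryDistribution is made up for illustration. This is not the Kogito API, and whether Kogito's PredictionInputsDataDistribution samples with or without replacement is not shown in these examples.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Hypothetical, simplified stand-in for a data distribution backed by recorded rows.
// Not the Kogito API: rows are plain double[] instead of PredictionInput.
final class InMemoryDistribution {
    private final List<double[]> rows;
    private final Random random;

    InMemoryDistribution(List<double[]> rows, long seed) {
        this.rows = List.copyOf(rows);   // defensive, unmodifiable copy
        this.random = new Random(seed);  // seeded so samples are reproducible
    }

    boolean isEmpty() {
        return rows.isEmpty();
    }

    // Draws n rows uniformly at random, with replacement; returns an empty list if no data.
    List<double[]> sample(int n) {
        List<double[]> out = new ArrayList<>(n);
        if (rows.isEmpty()) {
            return out;
        }
        for (int i = 0; i < n; i++) {
            out.add(rows.get(random.nextInt(rows.size())));
        }
        return out;
    }
}

public class Main {
    public static void main(String[] args) {
        InMemoryDistribution dd = new InMemoryDistribution(
                List.of(new double[] {1, 2}, new double[] {3, 4}), 42L);
        System.out.println(dd.sample(5).size()); // prints 5
    }
}
```

Sampling with replacement lets sample(n) return more rows than were recorded. That appears to match how the later examples request config.getSeriesLength() or 2 * size samples, but it remains an assumption.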

Example 27 with DataDistribution

Use of org.kie.kogito.explainability.model.DataDistribution in project kogito-apps by kiegroup.

From the class PartialDependencePlotExplainer, method explainFromDataDistribution.

private List<PartialDependenceGraph> explainFromDataDistribution(PredictionProvider model, int outputSize, DataDistribution dataDistribution) throws InterruptedException, ExecutionException, TimeoutException {
    long start = System.currentTimeMillis();
    List<PartialDependenceGraph> pdps = new ArrayList<>();
    List<FeatureDistribution> featureDistributions = dataDistribution.asFeatureDistributions();
    // fetch entire data distributions for all features
    List<PredictionInput> trainingData = dataDistribution.sample(config.getSeriesLength());
    // create a PDP for each feature
    for (FeatureDistribution featureDistribution : featureDistributions) {
        // generate (further) samples for the feature under analysis
        // TBD: maybe just reuse trainingData
        List<Value> xsValues = featureDistribution.sample(config.getSeriesLength()).stream()
                // sort alphanumerically first (this order survives for values whose Value#asNumber is NaN)
                .sorted(Comparator.comparing(Value::asString))
                // then sort by natural numeric order; Stream#sorted is stable on ordered streams,
                // so values that compare equal here keep the alphanumeric order from the previous step
                .sorted(Comparator.comparingDouble(Value::asNumber))
                .distinct()
                .collect(Collectors.toList());
        // transform sampled Values into Features
        List<Feature> featureXSvalues = xsValues.stream()
                .map(v -> FeatureFactory.copyOf(featureDistribution.getFeature(), v))
                .collect(Collectors.toList());
        // create a PDP for each feature and each output
        for (int outputIndex = 0; outputIndex < outputSize; outputIndex++) {
            PartialDependenceGraph partialDependenceGraph = getPartialDependenceGraph(model, trainingData, xsValues, featureXSvalues, outputIndex);
            pdps.add(partialDependenceGraph);
        }
    }
    long end = System.currentTimeMillis();
    LOGGER.debug("explanation time: {}ms", (end - start));
    return pdps;
}
Also used: FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) LoggerFactory(org.slf4j.LoggerFactory) TimeoutException(java.util.concurrent.TimeoutException) HashMap(java.util.HashMap) Value(org.kie.kogito.explainability.model.Value) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) ArrayList(java.util.ArrayList) PartialDependenceGraph(org.kie.kogito.explainability.model.PartialDependenceGraph) Map(java.util.Map) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) GlobalExplainer(org.kie.kogito.explainability.global.GlobalExplainer) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Logger(org.slf4j.Logger) Collection(java.util.Collection) Collectors(java.util.stream.Collectors) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) PredictionProviderMetadata(org.kie.kogito.explainability.model.PredictionProviderMetadata) ExecutionException(java.util.concurrent.ExecutionException) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Output(org.kie.kogito.explainability.model.Output) Comparator(java.util.Comparator) Config(org.kie.kogito.explainability.Config)
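The excerpt delegates the core computation to getPartialDependenceGraph, whose body is not shown here. The sketch below illustrates the textbook partial-dependence computation, not Kogito's code. It assumes a numeric model given as a ToDoubleFunction<double[]> and training rows as double[]. First it builds a sorted, de-duplicated grid of x values, as the excerpt does. Then, for each grid value, it overwrites the analysed feature in every training row, predicts, and averages the predictions. The names grid and partialDependence are hypothetical.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.ToDoubleFunction;

public class Main {

    // Sorted, de-duplicated grid of x values for the feature under analysis.
    static double[] grid(double[] sampledValues) {
        return Arrays.stream(sampledValues).sorted().distinct().toArray();
    }

    // Textbook partial dependence (hypothetical helper, not Kogito's implementation):
    // for each grid value, set the feature to that value in every training row,
    // predict, and average the predictions over all rows.
    static double[] partialDependence(ToDoubleFunction<double[]> model,
                                      List<double[]> trainingData,
                                      int featureIndex,
                                      double[] xs) {
        double[] ys = new double[xs.length];
        for (int i = 0; i < xs.length; i++) {
            double sum = 0.0;
            for (double[] row : trainingData) {
                double[] copy = row.clone();   // never mutate the training data
                copy[featureIndex] = xs[i];
                sum += model.applyAsDouble(copy);
            }
            ys[i] = sum / trainingData.size();
        }
        return ys;
    }

    public static void main(String[] args) {
        // Toy linear model: y = 2 * x0 + x1
        ToDoubleFunction<double[]> model = r -> 2 * r[0] + r[1];
        List<double[]> training = List.of(new double[] {0, 1}, new double[] {5, 3});
        double[] xs = grid(new double[] {2, 1, 2, 0});
        System.out.println(Arrays.toString(xs)); // prints [0.0, 1.0, 2.0]
        System.out.println(Arrays.toString(partialDependence(model, training, 0, xs)));
        // prints [2.0, 4.0, 6.0]
    }
}
```

For a linear model the resulting curve has the model's slope for that feature (2 here). This is a quick sanity check that the averaging is over training rows and not over grid points.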

Example 28 with DataDistribution

Use of org.kie.kogito.explainability.model.DataDistribution in project kogito-apps by kiegroup.

From the class LimeExplainer, method getPerturbedInputs.

private List<PredictionInput> getPerturbedInputs(List<Feature> features, LimeConfig executionConfig, PredictionProvider predictionProvider) {
    List<PredictionInput> perturbedInputs = new ArrayList<>();
    int size = executionConfig.getNoOfSamples();
    DataDistribution dataDistribution = executionConfig.getDataDistribution();
    Map<String, FeatureDistribution> featureDistributionsMap;
    PerturbationContext perturbationContext = executionConfig.getPerturbationContext();
    if (!dataDistribution.isEmpty()) {
        Map<String, HighScoreNumericFeatureZones> numericFeatureZonesMap;
        int max = executionConfig.getBoostrapInputs();
        if (executionConfig.isHighScoreFeatureZones()) {
            numericFeatureZonesMap = HighScoreNumericFeatureZonesProvider.getHighScoreFeatureZones(dataDistribution, predictionProvider, features, max);
        } else {
            numericFeatureZonesMap = new HashMap<>();
        }
        // generate feature distributions, if possible
        featureDistributionsMap = DataUtils.boostrapFeatureDistributions(dataDistribution, perturbationContext, 2 * size, 1, Math.min(size, max), numericFeatureZonesMap);
    } else {
        featureDistributionsMap = new HashMap<>();
    }
    for (int i = 0; i < size; i++) {
        List<Feature> newFeatures = DataUtils.perturbFeatures(features, perturbationContext, featureDistributionsMap);
        perturbedInputs.add(new PredictionInput(newFeatures));
    }
    return perturbedInputs;
}
Also used: PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) Feature(org.kie.kogito.explainability.model.Feature) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) DataDistribution(org.kie.kogito.explainability.model.DataDistribution)
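getPerturbedInputs relies on DataUtils.perturbFeatures, whose implementation is not shown. The plain-Java sketch below illustrates the general idea of LIME-style perturbation. It produces size perturbed copies of a numeric input vector. In each copy it changes noOfPerturbations randomly chosen features, taking the replacement from that feature's observed values when any are available and otherwise adding Gaussian noise. The method name perturb, the parameter name noOfPerturbations, the map of observed values and the Gaussian fallback are all assumptions for illustration, not Kogito's actual perturbation strategy.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;

public class Main {

    // Hypothetical LIME-style perturbation of a numeric input vector (not Kogito's code).
    // featureSamples maps a feature index to observed values for that feature (may be absent).
    static List<double[]> perturb(double[] input, int size, int noOfPerturbations,
                                  Map<Integer, double[]> featureSamples, Random random) {
        List<double[]> perturbed = new ArrayList<>(size);
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < input.length; i++) {
            indices.add(i);
        }
        int k = Math.min(noOfPerturbations, input.length);
        for (int s = 0; s < size; s++) {
            double[] copy = input.clone();
            Collections.shuffle(indices, random);   // choose k distinct features at random
            for (int j = 0; j < k; j++) {
                int f = indices.get(j);
                double[] observed = featureSamples.get(f);
                if (observed != null && observed.length > 0) {
                    copy[f] = observed[random.nextInt(observed.length)]; // draw from the data
                } else {
                    copy[f] = input[f] + random.nextGaussian();          // fallback: add noise
                }
            }
            perturbed.add(copy);
        }
        return perturbed;
    }

    public static void main(String[] args) {
        Random random = new Random(0L);
        List<double[]> samples = perturb(new double[] {1.0, 2.0, 3.0}, 4, 1,
                Map.of(0, new double[] {10.0, 20.0}), random);
        System.out.println(samples.size()); // prints 4
    }
}
```

Drawing replacements from observed values keeps perturbed inputs closer to realistic data than pure noise. That is the apparent motivation for bootstrapping feature distributions from a non-empty DataDistribution in the excerpt above.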

Example 29 with DataDistribution

Use of org.kie.kogito.explainability.model.DataDistribution in project kogito-apps by kiegroup.

From the class TrafficViolationDmnLimeExplainerTest, method testTrafficViolationDMNExplanation.

@Test
void testTrafficViolationDMNExplanation() throws ExecutionException, InterruptedException, TimeoutException {
    PredictionProvider model = getModel();
    PredictionInput predictionInput = getTestInput();
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(predictionInput)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    Prediction prediction = new SimplePrediction(predictionInput, predictionOutputs.get(0));
    Random random = new Random();
    PerturbationContext perturbationContext = new PerturbationContext(0L, random, 1);
    LimeConfig limeConfig = new LimeConfig().withSamples(10).withPerturbationContext(perturbationContext);
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    for (Saliency saliency : saliencyMap.values()) {
        assertNotNull(saliency);
        List<String> strings = saliency.getTopFeatures(3).stream().map(f -> f.getFeature().getName()).collect(Collectors.toList());
        assertTrue(strings.contains("Actual Speed") || strings.contains("Speed Limit"));
    }
    assertDoesNotThrow(() -> ValidationUtils.validateLocalSaliencyStability(model, prediction, limeExplainer, 1, 0.3, 0.3));
    String decision = "Fine";
    List<PredictionInput> inputs = new ArrayList<>();
    for (int n = 0; n < 10; n++) {
        inputs.add(new PredictionInput(DataUtils.perturbFeatures(predictionInput.getFeatures(), perturbationContext)));
    }
    DataDistribution distribution = new PredictionInputsDataDistribution(inputs);
    int k = 2;
    int chunkSize = 5;
    double f1 = ExplainabilityMetrics.getLocalSaliencyF1(decision, model, limeExplainer, distribution, k, chunkSize);
    AssertionsForClassTypes.assertThat(f1).isBetween(0.5d, 1d);
}
Also used: FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) Assertions.assertNotNull(org.junit.jupiter.api.Assertions.assertNotNull) DecisionModel(org.kie.kogito.decision.DecisionModel) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) Assertions.assertThat(org.assertj.core.api.Assertions.assertThat) AssertionsForClassTypes(org.assertj.core.api.AssertionsForClassTypes) TimeoutException(java.util.concurrent.TimeoutException) DmnDecisionModel(org.kie.kogito.dmn.DmnDecisionModel) HashMap(java.util.HashMap) Random(java.util.Random) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) Saliency(org.kie.kogito.explainability.model.Saliency) ArrayList(java.util.ArrayList) Map(java.util.Map) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) DMNRuntime(org.kie.dmn.api.core.DMNRuntime) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) LinkedList(java.util.LinkedList) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) LimeConfigOptimizer(org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) DataUtils(org.kie.kogito.explainability.utils.DataUtils) InputStreamReader(java.io.InputStreamReader) Collectors(java.util.stream.Collectors) DMNKogito(org.kie.kogito.dmn.DMNKogito) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ExecutionException(java.util.concurrent.ExecutionException) Test(org.junit.jupiter.api.Test) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) ExplainabilityMetrics(org.kie.kogito.explainability.utils.ExplainabilityMetrics) Assertions.assertTrue(org.junit.jupiter.api.Assertions.assertTrue) ValidationUtils(org.kie.kogito.explainability.utils.ValidationUtils) Config(org.kie.kogito.explainability.Config) Assertions.assertDoesNotThrow(org.junit.jupiter.api.Assertions.assertDoesNotThrow)
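The test ends by checking ExplainabilityMetrics.getLocalSaliencyF1, whose result is asserted to lie within a sub-range of [0, 1]. How that method decides what counts as a true or false positive is not shown in these examples. For reference only, the sketch below computes the standard F1 score from such counts: the harmonic mean of precision and recall. The class and method names are hypothetical.

```java
public class Main {

    // Standard F1 score: harmonic mean of precision and recall.
    // Returns 0 when there are no true positives (precision and recall are then 0 or undefined).
    static double f1(int truePositives, int falsePositives, int falseNegatives) {
        if (truePositives == 0) {
            return 0.0;
        }
        double precision = (double) truePositives / (truePositives + falsePositives);
        double recall = (double) truePositives / (truePositives + falseNegatives);
        return 2 * precision * recall / (precision + recall);
    }

    public static void main(String[] args) {
        // precision = 1 / 2, recall = 1 / 2
        System.out.println(f1(1, 1, 1)); // prints 0.5
    }
}
```

Because F1 is a harmonic mean, it stays low unless precision and recall are both high. This is why a bound such as isBetween(0.5d, 1d) is a meaningful requirement.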

Example 30 with DataDistribution

Use of org.kie.kogito.explainability.model.DataDistribution in project kogito-apps by kiegroup.

From the class PmmlRegressionCategoricalLimeExplainerTest, method testPMMLRegressionCategorical.

@Disabled("See KOGITO-6154")
@Test
void testPMMLRegressionCategorical() throws Exception {
    PredictionInput input = getTestInput();
    Random random = new Random();
    LimeConfig limeConfig = new LimeConfig().withSamples(10).withAdaptiveVariance(true).withPerturbationContext(new PerturbationContext(0L, random, 1));
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    PredictionProvider model = getModel();
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    assertThat(predictionOutputs).isNotNull().isNotEmpty();
    PredictionOutput output = predictionOutputs.get(0);
    assertThat(output).isNotNull();
    Prediction prediction = new SimplePrediction(input, output);
    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    for (Saliency saliency : saliencyMap.values()) {
        assertThat(saliency).isNotNull();
        double v = ExplainabilityMetrics.impactScore(model, prediction, saliency.getTopFeatures(2));
        assertThat(v).isEqualTo(1d);
    }
    assertDoesNotThrow(() -> ValidationUtils.validateLocalSaliencyStability(model, prediction, limeExplainer, 1, 0.5, 0.5));
    List<PredictionInput> inputs = getSamples();
    DataDistribution distribution = new PredictionInputsDataDistribution(inputs);
    String decision = "result";
    int k = 1;
    int chunkSize = 2;
    double f1 = ExplainabilityMetrics.getLocalSaliencyF1(decision, model, limeExplainer, distribution, k, chunkSize);
    AssertionsForClassTypes.assertThat(f1).isBetween(0d, 1d);
}
Also used: SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) Prediction(org.kie.kogito.explainability.model.Prediction) Saliency(org.kie.kogito.explainability.model.Saliency) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) Random(java.util.Random) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Test(org.junit.jupiter.api.Test) Disabled(org.junit.jupiter.api.Disabled)
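The disabled test asserts that ExplainabilityMetrics.impactScore returns 1.0 for the top two features. Judging from that assertion, the score appears to grow when removing the most important features changes the prediction, but its exact definition is not shown here. The sketch below is a simplified, hypothetical version of that idea. It neutralizes the given top features by setting them to 0.0 (an assumption; Kogito's way of dropping a feature may differ), re-runs a model given as a Function<double[], double[]>, and returns the fraction of outputs whose value changed.

```java
import java.util.Set;
import java.util.function.Function;

public class Main {

    // Hypothetical impact score (not Kogito's implementation): neutralize the top features
    // (set to 0.0 here, an assumption), re-run the model, and return the fraction of
    // outputs whose value changed.
    static double impactScore(Function<double[], double[]> model, double[] input, Set<Integer> topFeatures) {
        double[] original = model.apply(input);
        double[] dropped = input.clone();
        for (int f : topFeatures) {
            dropped[f] = 0.0;
        }
        double[] perturbed = model.apply(dropped);
        int changed = 0;
        for (int i = 0; i < original.length; i++) {
            if (Double.compare(original[i], perturbed[i]) != 0) {
                changed++;
            }
        }
        return original.length == 0 ? 0.0 : (double) changed / original.length;
    }

    public static void main(String[] args) {
        // Toy model with two outputs: the first depends on x0, the second only on x2.
        Function<double[], double[]> model = x -> new double[] {x[0] > 1 ? 1.0 : 0.0, x[2]};
        double score = impactScore(model, new double[] {5.0, 1.0, 3.0}, Set.of(0, 1));
        System.out.println(score); // prints 0.5
    }
}
```

With a single output, as with the "result" decision in the test, this simplified score can only be 0.0 or 1.0. That is consistent with the test's exact isEqualTo(1d) assertion.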

Aggregations

DataDistribution (org.kie.kogito.explainability.model.DataDistribution): 32
PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 27
PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext): 25
Prediction (org.kie.kogito.explainability.model.Prediction): 25
PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 25
ArrayList (java.util.ArrayList): 24
Random (java.util.Random): 24
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 24
PredictionInputsDataDistribution (org.kie.kogito.explainability.model.PredictionInputsDataDistribution): 21
Saliency (org.kie.kogito.explainability.model.Saliency): 20
SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction): 20
Test (org.junit.jupiter.api.Test): 19
Feature (org.kie.kogito.explainability.model.Feature): 18
LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig): 14
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 12
LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer): 12
FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance): 11
LinkedList (java.util.LinkedList): 9
ValueSource (org.junit.jupiter.params.provider.ValueSource): 8
FeatureDistribution (org.kie.kogito.explainability.model.FeatureDistribution): 8