Search in sources:

Example 91 with PredictionProvider

Use of org.kie.kogito.explainability.model.PredictionProvider in the kogito-apps project by kiegroup.

Source: class AggregatedLimeExplainerTest, method testExplainWithPredictions.

@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testExplainWithPredictions(int seed) throws ExecutionException, InterruptedException {
    // Seeded RNG: new Random(seed) is equivalent to new Random() followed by setSeed(seed).
    Random rng = new Random(seed);
    // NOTE(review): presumably a model that sums every feature except index 1; its output is
    // expected to be named "sum-but1" (see the assertions below).
    PredictionProvider model = TestUtils.getSumSkipModel(1);

    // Sample 10 inputs from a randomly generated distribution and score them with the model.
    DataDistribution distribution = DataUtils.generateRandomDataDistribution(3, 100, rng);
    List<PredictionInput> inputs = distribution.sample(10);
    List<PredictionOutput> outputs = model.predictAsync(inputs).get();
    List<Prediction> predictionList = DataUtils.getPredictions(inputs, outputs);

    // Aggregate LIME explanations over the whole batch of predictions.
    AggregatedLimeExplainer explainer = new AggregatedLimeExplainer();
    Map<String, Saliency> saliencyByOutput = explainer.explainFromPredictions(model, predictionList).get();

    // Exactly one saliency is expected, keyed by the "sum-but1" output.
    assertNotNull(saliencyByOutput);
    assertEquals(1, saliencyByOutput.size());
    assertTrue(saliencyByOutput.containsKey("sum-but1"));
    Saliency aggregated = saliencyByOutput.get("sum-but1");
    assertNotNull(aggregated);

    // The feature the model ignores ("f1") must not rank among the two most positive features.
    List<String> topPositiveNames = aggregated.getPositiveFeatures(2)
            .stream()
            .map(importance -> importance.getFeature().getName())
            .collect(Collectors.toList());
    assertFalse(topPositiveNames.contains("f1"));
}
Also used: PredictionInput (org.kie.kogito.explainability.model.PredictionInput), Prediction (org.kie.kogito.explainability.model.Prediction), Saliency (org.kie.kogito.explainability.model.Saliency), PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider), Random (java.util.Random), DataDistribution (org.kie.kogito.explainability.model.DataDistribution), FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance), PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput), ValueSource (org.junit.jupiter.params.provider.ValueSource), ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)

Example 92 with PredictionProvider

Use of org.kie.kogito.explainability.model.PredictionProvider in the kogito-apps project by kiegroup.

Source: class PartialDependencePlotExplainerTest, method testTextClassifier.

@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testTextClassifier(int seed) throws Exception {
    // NOTE(review): nothing in this test consumes a Random, so `seed` currently has no effect
    // and every parameterized run is identical. The previously created (never read) seeded
    // Random was dead code and has been removed; wire the seed into the explainer/model or
    // switch to @Test if no randomness is involved.
    PartialDependencePlotExplainer partialDependencePlotExplainer = new PartialDependencePlotExplainer();
    PredictionProvider model = TestUtils.getDummyTextClassifier();
    List<String> texts = List.of("we want your money", "please reply quickly", "you are the lucky winner", "huge donation for you!", "bitcoin for you");
    // One prediction is collected per text, so size the list accordingly (was a fixed 3 for 5 texts).
    Collection<Prediction> predictions = new ArrayList<>(texts.size());
    for (String text : texts) {
        // Each input holds a single full-text feature named "text".
        List<Feature> features = new ArrayList<>();
        features.add(FeatureFactory.newFulltextFeature("text", text));
        PredictionInput predictionInput = new PredictionInput(features);
        PredictionOutput predictionOutput = model.predictAsync(List.of(predictionInput)).get().get(0);
        predictions.add(new SimplePrediction(predictionInput, predictionOutput));
    }
    // Build partial dependence graphs from the collected predictions; at least one is expected.
    List<PartialDependenceGraph> pdps = partialDependencePlotExplainer.explainFromPredictions(model, predictions);
    assertThat(pdps).isNotEmpty();
}
Also used: SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction), PredictionInput (org.kie.kogito.explainability.model.PredictionInput), Prediction (org.kie.kogito.explainability.model.Prediction), ArrayList (java.util.ArrayList), PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider), Feature (org.kie.kogito.explainability.model.Feature), Random (java.util.Random), PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput), PartialDependenceGraph (org.kie.kogito.explainability.model.PartialDependenceGraph), ValueSource (org.junit.jupiter.params.provider.ValueSource), ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)

Example 93 with PredictionProvider

Use of org.kie.kogito.explainability.model.PredictionProvider in the kogito-apps project by kiegroup.

Source: class CounterfactualExplainerTest, method testConsumers.

@ParameterizedTest
@ValueSource(ints = { 0, 1, 2 })
void testConsumers(int seed) throws ExecutionException, InterruptedException, TimeoutException {
    // The seed drives the solver (setRandomSeed below); the unused local Random was removed.
    // Goal: the "inside" output should become true.
    final List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0));
    List<Feature> features = new LinkedList<>();
    features.add(FeatureFactory.newNumericalFeature("f-num1", 100.0, NumericalFeatureDomain.create(0.0, 1000.0)));
    features.add(FeatureFactory.newNumericalFeature("f-num2", 100.0, NumericalFeatureDomain.create(0.0, 1000.0)));
    features.add(FeatureFactory.newNumericalFeature("f-num3", 100.0, NumericalFeatureDomain.create(0.0, 1000.0)));
    features.add(FeatureFactory.newNumericalFeature("f-num4", 100.0, NumericalFeatureDomain.create(0.0, 1000.0)));
    // Cap the search at 10,000 score calculations; this test only needs the solver to run long
    // enough to publish at least one intermediate result to the consumer.
    final TerminationConfig terminationConfig = new TerminationConfig().withScoreCalculationCountLimit(10_000L);
    final SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(terminationConfig).build();
    solverConfig.setRandomSeed((long) seed);
    solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
    // Mock consumer that receives intermediate counterfactual results during the search.
    @SuppressWarnings("unchecked") final Consumer<CounterfactualResult> assertIntermediateCounterfactualNotNull = mock(Consumer.class);
    final CounterfactualConfig counterfactualConfig = new CounterfactualConfig().withSolverConfig(solverConfig).withGoalThreshold(0.01);
    final CounterfactualExplainer counterfactualExplainer = new CounterfactualExplainer(counterfactualConfig);
    PredictionInput input = new PredictionInput(features);
    // NOTE(review): presumably "inside" is true when the feature sum lies within center +/- epsilon.
    final double center = 500.0;
    final double epsilon = 10.0;
    final PredictionProvider model = TestUtils.getSumThresholdModel(center, epsilon);
    PredictionOutput output = new PredictionOutput(goal);
    Prediction prediction = new CounterfactualPrediction(input, output, null, UUID.randomUUID(), null);
    final CounterfactualResult counterfactualResult = counterfactualExplainer.explainAsync(prediction, model, assertIntermediateCounterfactualNotNull).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    // Check before dereferencing so a missing result fails as an assertion rather than an NPE.
    assertNotNull(counterfactualResult);
    for (CounterfactualEntity entity : counterfactualResult.getEntities()) {
        logger.debug("Entity: {}", entity);
    }
    logger.debug("Outputs: {}", counterfactualResult.getOutput().get(0).getOutputs());
    // At least one intermediate result is generated. any(Class) excludes null (unlike the
    // no-arg any()), so this really verifies a non-null intermediate result was delivered.
    verify(assertIntermediateCounterfactualNotNull, atLeast(1)).accept(any(CounterfactualResult.class));
}
Also used: PredictionInput (org.kie.kogito.explainability.model.PredictionInput), Prediction (org.kie.kogito.explainability.model.Prediction), CounterfactualPrediction (org.kie.kogito.explainability.model.CounterfactualPrediction), PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider), Feature (org.kie.kogito.explainability.model.Feature), LinkedList (java.util.LinkedList), CounterfactualEntity (org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity), TerminationConfig (org.optaplanner.core.config.solver.termination.TerminationConfig), Random (java.util.Random), PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput), Output (org.kie.kogito.explainability.model.Output), Value (org.kie.kogito.explainability.model.Value), SolverConfig (org.optaplanner.core.config.solver.SolverConfig), ValueSource (org.junit.jupiter.params.provider.ValueSource), ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)

Example 94 with PredictionProvider

Use of org.kie.kogito.explainability.model.PredictionProvider in the kogito-apps project by kiegroup.

Source: class CounterfactualExplainerTest, method testNonEmptyInput.

@ParameterizedTest
@ValueSource(ints = { 0, 1, 2 })
void testNonEmptyInput(int seed) throws ExecutionException, InterruptedException, TimeoutException {
    // Seeded RNG used to produce the starting feature values.
    Random random = new Random();
    random.setSeed(seed);
    // Goal: the numeric "class" output should reach 10.0.
    final List<Output> goal = List.of(new Output("class", Type.NUMBER, new Value(10.0), 0.0d));
    List<Feature> features = new LinkedList<>();
    for (int i = 0; i < 4; i++) {
        features.add(FeatureFactory.newNumericalFeature("f-" + i, random.nextDouble(), NumericalFeatureDomain.create(0.0, 1000.0)));
    }
    // Only 10 score calculations: the test just checks that a result is produced, not its quality.
    final TerminationConfig terminationConfig = new TerminationConfig().withScoreCalculationCountLimit(10L);
    final SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(terminationConfig).build();
    solverConfig.setRandomSeed((long) seed);
    solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
    final CounterfactualConfig counterfactualConfig = new CounterfactualConfig().withSolverConfig(solverConfig);
    final CounterfactualExplainer counterfactualExplainer = new CounterfactualExplainer(counterfactualConfig);
    PredictionProvider model = TestUtils.getSumSkipModel(0);
    PredictionInput input = new PredictionInput(features);
    PredictionOutput output = new PredictionOutput(goal);
    Prediction prediction = new CounterfactualPrediction(input, output, null, UUID.randomUUID(), null);
    final CounterfactualResult counterfactualResult = counterfactualExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    // Assert non-null BEFORE dereferencing: previously these checks ran after the result and its
    // entities had already been used, so a null would surface as an NPE instead of a failure.
    assertNotNull(counterfactualResult);
    assertNotNull(counterfactualResult.getEntities());
    for (CounterfactualEntity entity : counterfactualResult.getEntities()) {
        logger.debug("Entity: {}", entity);
    }
    logger.debug("Outputs: {}", counterfactualResult.getOutput().get(0).getOutputs());
}
Also used: PredictionInput (org.kie.kogito.explainability.model.PredictionInput), Prediction (org.kie.kogito.explainability.model.Prediction), CounterfactualPrediction (org.kie.kogito.explainability.model.CounterfactualPrediction), PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider), Feature (org.kie.kogito.explainability.model.Feature), LinkedList (java.util.LinkedList), CounterfactualEntity (org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity), TerminationConfig (org.optaplanner.core.config.solver.termination.TerminationConfig), Random (java.util.Random), PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput), Output (org.kie.kogito.explainability.model.Output), Value (org.kie.kogito.explainability.model.Value), SolverConfig (org.optaplanner.core.config.solver.SolverConfig), ValueSource (org.junit.jupiter.params.provider.ValueSource), ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)

Example 95 with PredictionProvider

Use of org.kie.kogito.explainability.model.PredictionProvider in the kogito-apps project by kiegroup.

Source: class CounterfactualExplainerTest, method testCounterfactualMatchThreshold.

@ParameterizedTest
@ValueSource(ints = { 0, 1, 2 })
void testCounterfactualMatchThreshold(int seed) throws ExecutionException, InterruptedException, TimeoutException {
    // The seed is passed straight to runCounterfactualSearch; the previously created local
    // Random was never read and has been removed.
    // Goal: "inside" should be true with a prediction score of at least scoreThreshold.
    final double scoreThreshold = 0.9;
    final List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), scoreThreshold));
    List<Feature> features = new LinkedList<>();
    features.add(FeatureFactory.newNumericalFeature("f-num1", 100.0, NumericalFeatureDomain.create(0.0, 1000.0)));
    features.add(FeatureFactory.newNumericalFeature("f-num2", 100.0, NumericalFeatureDomain.create(0.0, 1000.0)));
    features.add(FeatureFactory.newNumericalFeature("f-num3", 100.0, NumericalFeatureDomain.create(0.0, 1000.0)));
    features.add(FeatureFactory.newNumericalFeature("f-num4", 100.0, NumericalFeatureDomain.create(0.0, 1000.0)));
    final double center = 500.0;
    final double epsilon = 10.0;
    final PredictionProvider model = TestUtils.getSumThresholdModel(center, epsilon);
    final CounterfactualResult result = runCounterfactualSearch((long) seed, goal, features, model, DEFAULT_GOAL_THRESHOLD);
    // Fail with an assertion (not an NPE) if the search produced no result.
    assertNotNull(result);
    final List<CounterfactualEntity> counterfactualEntities = result.getEntities();
    double totalSum = 0;
    for (CounterfactualEntity entity : counterfactualEntities) {
        totalSum += entity.asFeature().getValue().asNumber();
        logger.debug("Entity: {}", entity);
    }
    logger.debug("Counterfactual feature sum: {}", totalSum);
    // The counterfactual features must sum to a value within center +/- epsilon.
    assertTrue(totalSum <= center + epsilon);
    assertTrue(totalSum >= center - epsilon);
    // Re-score the counterfactual with the model and check the score meets the goal threshold.
    final List<Feature> cfFeatures = counterfactualEntities.stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList());
    final PredictionInput cfInput = new PredictionInput(cfFeatures);
    final PredictionOutput cfOutput = model.predictAsync(List.of(cfInput)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()).get(0);
    final double predictionScore = cfOutput.getOutputs().get(0).getScore();
    logger.debug("Prediction score: {}", predictionScore);
    assertTrue(predictionScore >= scoreThreshold);
    assertTrue(result.isValid());
}
Also used: PredictionInput (org.kie.kogito.explainability.model.PredictionInput), PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider), Feature (org.kie.kogito.explainability.model.Feature), LinkedList (java.util.LinkedList), CounterfactualEntity (org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity), Random (java.util.Random), PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput), Output (org.kie.kogito.explainability.model.Output), Value (org.kie.kogito.explainability.model.Value), ValueSource (org.junit.jupiter.params.provider.ValueSource), ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)

Aggregations

PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 158; Prediction (org.kie.kogito.explainability.model.Prediction): 134; PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 134; PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 126; Test (org.junit.jupiter.api.Test): 109; SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction): 99; Random (java.util.Random): 91; Feature (org.kie.kogito.explainability.model.Feature): 76; ArrayList (java.util.ArrayList): 73; PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext): 69; ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 64; LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig): 59; LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer): 54; Output (org.kie.kogito.explainability.model.Output): 45; Saliency (org.kie.kogito.explainability.model.Saliency): 45; LinkedList (java.util.LinkedList): 41; Value (org.kie.kogito.explainability.model.Value): 41; List (java.util.List): 37; LimeConfigOptimizer (org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer): 33; ValueSource (org.junit.jupiter.params.provider.ValueSource): 32