
Example 96 with PredictionProvider

Use of org.kie.kogito.explainability.model.PredictionProvider in project kogito-apps by kiegroup, from the class CounterfactualExplainerTest, method testCounterfactualMatchNoThreshold.

@ParameterizedTest
@ValueSource(ints = { 0, 1, 2 })
void testCounterfactualMatchNoThreshold(int seed) throws ExecutionException, InterruptedException, TimeoutException {
    Random random = new Random();
    random.setSeed(seed);
    final double scoreThreshold = 0.0;
    final List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), scoreThreshold));
    List<Feature> features = new LinkedList<>();
    features.add(FeatureFactory.newNumericalFeature("f-num1", 100.0, NumericalFeatureDomain.create(0.0, 1000.0)));
    features.add(FeatureFactory.newNumericalFeature("f-num2", 100.0, NumericalFeatureDomain.create(0.0, 1000.0)));
    features.add(FeatureFactory.newNumericalFeature("f-num3", 100.0, NumericalFeatureDomain.create(0.0, 1000.0)));
    features.add(FeatureFactory.newNumericalFeature("f-num4", 100.0, NumericalFeatureDomain.create(0.0, 1000.0)));
    final double center = 500.0;
    final double epsilon = 10.0;
    final PredictionProvider model = TestUtils.getSumThresholdModel(center, epsilon);
    final CounterfactualResult result = runCounterfactualSearch((long) seed, goal, features, model, DEFAULT_GOAL_THRESHOLD);
    final List<CounterfactualEntity> counterfactualEntities = result.getEntities();
    double totalSum = 0;
    for (CounterfactualEntity entity : counterfactualEntities) {
        totalSum += entity.asFeature().getValue().asNumber();
        logger.debug("Entity: {}", entity);
    }
    assertTrue(totalSum <= center + epsilon);
    assertTrue(totalSum >= center - epsilon);
    final List<Feature> cfFeatures = counterfactualEntities.stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList());
    final PredictionInput cfInput = new PredictionInput(cfFeatures);
    final PredictionOutput cfOutput = model.predictAsync(List.of(cfInput)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()).get(0);
    final double predictionScore = cfOutput.getOutputs().get(0).getScore();
    logger.debug("Prediction score: {}", predictionScore);
    assertTrue(predictionScore < 0.5);
    assertTrue(result.isValid());
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) LinkedList(java.util.LinkedList) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) Random(java.util.Random) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) Value(org.kie.kogito.explainability.model.Value) ValueSource(org.junit.jupiter.params.provider.ValueSource) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)
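
TestUtils.getSumThresholdModel(center, epsilon) is not shown in this listing. As a rough illustration only, a PredictionProvider with the behaviour this test relies on (a boolean "inside" output that is true when the sum of the numerical feature values falls within center ± epsilon) could be sketched as below. This assumes PredictionProvider is a single-method interface implementable as a lambda and additionally needs java.util.concurrent.CompletableFuture; the real helper also computes the Output confidence score differently (the test above expects that score to end up below 0.5), so a fixed placeholder score is used here.

static PredictionProvider sumThresholdModelSketch(double center, double epsilon) {
    return inputs -> CompletableFuture.supplyAsync(() -> inputs.stream().map(input -> {
        // sum all numerical feature values of this input
        double sum = input.getFeatures().stream()
                .mapToDouble(f -> f.getValue().asNumber())
                .sum();
        // "inside" is true when the sum lies within [center - epsilon, center + epsilon]
        boolean inside = Math.abs(sum - center) <= epsilon;
        // placeholder score of 0.0; the actual TestUtils model scores its output differently
        return new PredictionOutput(
                List.of(new Output("inside", Type.BOOLEAN, new Value(inside), 0.0)));
    }).collect(Collectors.toList()));
}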

Example 97 with PredictionProvider

Use of org.kie.kogito.explainability.model.PredictionProvider in project kogito-apps by kiegroup, from the class CounterfactualScoreCalculatorTest, method testNullBooleanInput.

/**
 * Null values for input Boolean features should be accepted as valid
 */
@Test
void testNullBooleanInput() throws ExecutionException, InterruptedException {
    final CounterFactualScoreCalculator scoreCalculator = new CounterFactualScoreCalculator();
    PredictionProvider model = TestUtils.getFeatureSkipModel(0);
    List<Feature> features = new ArrayList<>();
    List<FeatureDomain> featureDomains = new ArrayList<>();
    List<Boolean> constraints = new ArrayList<>();
    // f-1
    features.add(FeatureFactory.newNumericalFeature("f-1", 1.0));
    featureDomains.add(NumericalFeatureDomain.create(0.0, 10.0));
    constraints.add(false);
    // f-2
    features.add(FeatureFactory.newBooleanFeature("f-2", null));
    featureDomains.add(EmptyFeatureDomain.create());
    constraints.add(false);
    // f-3
    features.add(FeatureFactory.newBooleanFeature("f-3", true));
    featureDomains.add(EmptyFeatureDomain.create());
    constraints.add(false);
    PredictionInput input = new PredictionInput(features);
    PredictionFeatureDomain domains = new PredictionFeatureDomain(featureDomains);
    List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(input);
    List<Output> goal = new ArrayList<>();
    goal.add(new Output("f-2", Type.BOOLEAN, new Value(null), 0.0));
    goal.add(new Output("f-3", Type.BOOLEAN, new Value(true), 0.0));
    final CounterfactualSolution solution = new CounterfactualSolution(entities, features, model, goal, UUID.randomUUID(), UUID.randomUUID(), 0.0);
    BendableBigDecimalScore score = scoreCalculator.calculateScore(solution);
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input)).get();
    assertTrue(score.isFeasible());
    assertEquals(2, goal.size());
    // A single prediction is expected
    assertEquals(1, predictionOutputs.size());
    // Single prediction with two features
    assertEquals(2, predictionOutputs.get(0).getOutputs().size());
    assertEquals(0, score.getHardScore(0).compareTo(BigDecimal.ZERO));
    assertEquals(0, score.getHardScore(1).compareTo(BigDecimal.ZERO));
    assertEquals(0, score.getHardScore(2).compareTo(BigDecimal.ZERO));
    assertEquals(0, score.getSoftScore(0).compareTo(BigDecimal.ZERO));
    assertEquals(0, score.getSoftScore(1).compareTo(BigDecimal.ZERO));
    assertEquals(3, score.getHardLevelsSize());
    assertEquals(2, score.getSoftLevelsSize());
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) BendableBigDecimalScore(org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore) EmptyFeatureDomain(org.kie.kogito.explainability.model.domain.EmptyFeatureDomain) PredictionFeatureDomain(org.kie.kogito.explainability.model.PredictionFeatureDomain) NumericalFeatureDomain(org.kie.kogito.explainability.model.domain.NumericalFeatureDomain) FeatureDomain(org.kie.kogito.explainability.model.domain.FeatureDomain) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) Value(org.kie.kogito.explainability.model.Value) Test(org.junit.jupiter.api.Test) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)
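
TestUtils.getFeatureSkipModel(0) is likewise not included here. Judging only from the assertions above (three input features, two outputs), it appears to drop the feature at the given index and echo the remaining features back as outputs. The sketch below is an assumption in that spirit, not the actual kogito-apps helper, and needs java.util.concurrent.CompletableFuture in addition to the imports listed.

static PredictionProvider featureSkipModelSketch(int skippedIndex) {
    return inputs -> CompletableFuture.supplyAsync(() -> inputs.stream().map(input -> {
        List<Feature> fs = input.getFeatures();
        List<Output> outputs = new ArrayList<>();
        for (int i = 0; i < fs.size(); i++) {
            if (i == skippedIndex) {
                continue; // drop the skipped feature
            }
            Feature f = fs.get(i);
            // echo the remaining features back as outputs, keeping name, type and value
            outputs.add(new Output(f.getName(), f.getType(), f.getValue(), 1.0));
        }
        return new PredictionOutput(outputs);
    }).collect(Collectors.toList()));
}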

Example 98 with PredictionProvider

Use of org.kie.kogito.explainability.model.PredictionProvider in project kogito-apps by kiegroup, from the class CounterfactualScoreCalculatorTest, method testGoalSizeLarger.

/**
 * Using a goal with more outputs (3) than the model produces (2) should
 * throw an {@link IllegalArgumentException} with the appropriate message.
 */
@Test
void testGoalSizeLarger() throws ExecutionException, InterruptedException {
    final CounterFactualScoreCalculator scoreCalculator = new CounterFactualScoreCalculator();
    PredictionProvider model = TestUtils.getFeatureSkipModel(0);
    List<Feature> features = new ArrayList<>();
    List<FeatureDomain> featureDomains = new ArrayList<>();
    List<Boolean> constraints = new ArrayList<>();
    // f-1
    features.add(FeatureFactory.newNumericalFeature("f-1", 1.0));
    featureDomains.add(NumericalFeatureDomain.create(0.0, 10.0));
    constraints.add(false);
    // f-2
    features.add(FeatureFactory.newNumericalFeature("f-2", 2.0));
    featureDomains.add(NumericalFeatureDomain.create(0.0, 10.0));
    constraints.add(false);
    // f-3
    features.add(FeatureFactory.newBooleanFeature("f-3", true));
    featureDomains.add(EmptyFeatureDomain.create());
    constraints.add(false);
    PredictionInput input = new PredictionInput(features);
    PredictionFeatureDomain domains = new PredictionFeatureDomain(featureDomains);
    List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(input);
    List<Output> goal = new ArrayList<>();
    goal.add(new Output("f-1", Type.NUMBER, new Value(1.0), 0.0));
    goal.add(new Output("f-2", Type.NUMBER, new Value(2.0), 0.0));
    goal.add(new Output("f-3", Type.BOOLEAN, new Value(true), 0.0));
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input)).get();
    assertEquals(3, goal.size());
    // A single prediction is expected
    assertEquals(1, predictionOutputs.size());
    // Single prediction with two features
    assertEquals(2, predictionOutputs.get(0).getOutputs().size());
    final CounterfactualSolution solution = new CounterfactualSolution(entities, features, model, goal, UUID.randomUUID(), UUID.randomUUID(), 0.0);
    IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> {
        scoreCalculator.calculateScore(solution);
    });
    assertEquals("Prediction size must be equal to goal size", exception.getMessage());
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) EmptyFeatureDomain(org.kie.kogito.explainability.model.domain.EmptyFeatureDomain) PredictionFeatureDomain(org.kie.kogito.explainability.model.PredictionFeatureDomain) NumericalFeatureDomain(org.kie.kogito.explainability.model.domain.NumericalFeatureDomain) FeatureDomain(org.kie.kogito.explainability.model.domain.FeatureDomain) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) Value(org.kie.kogito.explainability.model.Value) Test(org.junit.jupiter.api.Test) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)
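
The test above only checks the exception message, so the guard it exercises is presumably a plain size comparison inside CounterFactualScoreCalculator.calculateScore. A hypothetical sketch of such a check is shown below, where predictionOutput stands for the model's prediction for the current solution; the real implementation may differ.

// hypothetical guard, illustrative only
List<Output> predictedOutputs = predictionOutput.getOutputs();
if (goal.size() != predictedOutputs.size()) {
    throw new IllegalArgumentException("Prediction size must be equal to goal size");
}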

Example 99 with PredictionProvider

Use of org.kie.kogito.explainability.model.PredictionProvider in project kogito-apps by kiegroup, from the class CounterfactualScoreCalculatorTest, method testPrimarySoftScore.

/**
 * Test precision errors for the primary soft score.
 * When the primary soft score is calculated between features with the same numerical
 * value, a similarity of 1 is expected. For a large number of features, floating-point errors
 * may make this similarity slightly larger than 1, which would make the distance calculation
 * (Math.sqrt(1.0 - similarity)) fail. The score calculation method must not let this occur.
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testPrimarySoftScore(int seed) {
    final Random random = new Random(seed);
    final List<Feature> features = new ArrayList<>();
    final List<FeatureDomain> featureDomains = new ArrayList<>();
    final List<Boolean> constraints = new ArrayList<>();
    final int nFeatures = 1000;
    // Create a large number of identical features
    for (int n = 0; n < nFeatures; n++) {
        features.add(FeatureFactory.newNumericalFeature("f-" + n, random.nextDouble() * 1e-100));
        featureDomains.add(NumericalFeatureDomain.create(0.0, 10.0));
        constraints.add(false);
    }
    final PredictionInput input = new PredictionInput(features);
    final PredictionFeatureDomain domain = new PredictionFeatureDomain(featureDomains);
    final List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(input);
    // Create score calculator and model
    final CounterFactualScoreCalculator scoreCalculator = new CounterFactualScoreCalculator();
    PredictionProvider model = TestUtils.getFeatureSkipModel(0);
    // Create goal
    final List<Output> goal = new ArrayList<>();
    for (int n = 1; n < nFeatures; n++) {
        goal.add(new Output("f-" + n, Type.NUMBER, features.get(n).getValue(), 1.0));
    }
    final CounterfactualSolution solution = new CounterfactualSolution(entities, features, model, goal, UUID.randomUUID(), UUID.randomUUID(), 0.0);
    final BendableBigDecimalScore score = scoreCalculator.calculateScore(solution);
    assertEquals(0.0, score.getSoftScore(0).doubleValue(), 1e-5);
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) BendableBigDecimalScore(org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore) EmptyFeatureDomain(org.kie.kogito.explainability.model.domain.EmptyFeatureDomain) PredictionFeatureDomain(org.kie.kogito.explainability.model.PredictionFeatureDomain) NumericalFeatureDomain(org.kie.kogito.explainability.model.domain.NumericalFeatureDomain) FeatureDomain(org.kie.kogito.explainability.model.domain.FeatureDomain) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) Random(java.util.Random) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) ValueSource(org.junit.jupiter.params.provider.ValueSource) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)
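
One generic way to avoid the failure described in the Javadoc above is to clamp the similarity to 1.0 before turning it into a distance. The snippet below illustrates that idea; it is not necessarily the exact fix applied in kogito-apps.

// a similarity of 1.0 + 1e-16 caused by floating-point error would otherwise make
// Math.sqrt(1.0 - similarity) produce NaN; clamping keeps the distance well defined
double boundedSimilarity = Math.min(1.0, similarity);
double distance = Math.sqrt(1.0 - boundedSimilarity);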

Example 100 with PredictionProvider

Use of org.kie.kogito.explainability.model.PredictionProvider in project kogito-apps by kiegroup, from the class DummyModelsLimeExplainerTest, method testFixedOutput.

@ParameterizedTest
@ValueSource(longs = { 0 })
void testFixedOutput(long seed) throws Exception {
    Random random = new Random();
    List<Feature> features = new LinkedList<>();
    features.add(FeatureFactory.newNumericalFeature("f1", 6));
    features.add(FeatureFactory.newNumericalFeature("f2", 3));
    features.add(FeatureFactory.newNumericalFeature("f3", 5));
    PredictionProvider model = TestUtils.getFixedOutputClassifier();
    PredictionInput input = new PredictionInput(features);
    List<PredictionOutput> outputs = model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    Prediction prediction = new SimplePrediction(input, outputs.get(0));
    LimeConfig limeConfig = new LimeConfig().withSamples(10).withPerturbationContext(new PerturbationContext(seed, random, 1));
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    for (Saliency saliency : saliencyMap.values()) {
        assertNotNull(saliency);
        List<FeatureImportance> topFeatures = saliency.getTopFeatures(3);
        assertEquals(3, topFeatures.size());
        for (FeatureImportance featureImportance : topFeatures) {
            assertEquals(0, featureImportance.getScore());
        }
        assertEquals(0d, ExplainabilityMetrics.impactScore(model, prediction, topFeatures));
    }
    int topK = 1;
    double minimumPositiveStabilityRate = 0.5;
    double minimumNegativeStabilityRate = 0.5;
    TestUtils.assertLimeStability(model, prediction, limeExplainer, topK, minimumPositiveStabilityRate, minimumNegativeStabilityRate);
    List<PredictionInput> inputs = new ArrayList<>();
    for (int i = 0; i < 100; i++) {
        List<Feature> fs = new LinkedList<>();
        fs.add(TestUtils.getMockedNumericFeature());
        fs.add(TestUtils.getMockedNumericFeature());
        fs.add(TestUtils.getMockedNumericFeature());
        inputs.add(new PredictionInput(fs));
    }
    DataDistribution distribution = new PredictionInputsDataDistribution(inputs);
    int k = 2;
    int chunkSize = 10;
    String decision = "class";
    double precision = ExplainabilityMetrics.getLocalSaliencyPrecision(decision, model, limeExplainer, distribution, k, chunkSize);
    assertThat(precision).isEqualTo(1);
    double recall = ExplainabilityMetrics.getLocalSaliencyRecall(decision, model, limeExplainer, distribution, k, chunkSize);
    assertThat(recall).isEqualTo(1);
    double f1 = ExplainabilityMetrics.getLocalSaliencyF1(decision, model, limeExplainer, distribution, k, chunkSize);
    assertThat(f1).isEqualTo(1);
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) ArrayList(java.util.ArrayList) Saliency(org.kie.kogito.explainability.model.Saliency) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) LinkedList(java.util.LinkedList) Random(java.util.Random) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) ValueSource(org.junit.jupiter.params.provider.ValueSource) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)
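
TestUtils.getFixedOutputClassifier() is not part of this listing either. A minimal sketch consistent with the assertions above (the decision is named "class" and every feature importance is 0, since the output never changes) could look as follows; the concrete output value and score are assumptions, and Output, Type, Value and java.util.concurrent.CompletableFuture would need to be imported as well.

static PredictionProvider fixedOutputClassifierSketch() {
    return inputs -> CompletableFuture.supplyAsync(() -> inputs.stream()
            // the same output is returned for every input, so no feature has any impact
            .map(input -> new PredictionOutput(
                    List.of(new Output("class", Type.BOOLEAN, new Value(false), 1.0))))
            .collect(Collectors.toList()));
}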

Aggregations

PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 158
Prediction (org.kie.kogito.explainability.model.Prediction): 134
PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 134
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 126
Test (org.junit.jupiter.api.Test): 109
SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction): 99
Random (java.util.Random): 91
Feature (org.kie.kogito.explainability.model.Feature): 76
ArrayList (java.util.ArrayList): 73
PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext): 69
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 64
LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig): 59
LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer): 54
Output (org.kie.kogito.explainability.model.Output): 45
Saliency (org.kie.kogito.explainability.model.Saliency): 45
LinkedList (java.util.LinkedList): 41
Value (org.kie.kogito.explainability.model.Value): 41
List (java.util.List): 37
LimeConfigOptimizer (org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer): 33
ValueSource (org.junit.jupiter.params.provider.ValueSource): 32