Example 26 with LimeExplainer

use of org.kie.kogito.explainability.local.lime.LimeExplainer in project kogito-apps by kiegroup.

the class DefaultLimeOptimizationServiceTest method testNullConfig.

@Test
void testNullConfig() {
    LimeConfigOptimizer optimizer = new LimeConfigOptimizer();
    int max = 1;
    LimeOptimizationService service = new DefaultLimeOptimizationService(optimizer, max);
    assertThat(service.getBestConfigFor(new LimeExplainer())).isNull();
}
Also used : LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) Test(org.junit.jupiter.api.Test)

Example 27 with LimeExplainer

use of org.kie.kogito.explainability.local.lime.LimeExplainer in project kogito-apps by kiegroup.

the class LimeStabilityScoreCalculator method getStabilityScore.

private BigDecimal getStabilityScore(PredictionProvider model, LimeConfig config, List<Prediction> predictions) {
    double succeededEvaluations = 0;
    int topK = 2;
    BigDecimal stabilityScore = BigDecimal.ZERO;
    LimeExplainer limeExplainer = new LimeExplainer(config);
    for (Prediction prediction : predictions) {
        try {
            LocalSaliencyStability stability = ExplainabilityMetrics.getLocalSaliencyStability(model, prediction, limeExplainer, topK, NUM_RUNS);
            for (String decision : stability.getDecisions()) {
                BigDecimal decisionMarginalScore = getDecisionMarginalScore(stability, decision, topK);
                stabilityScore = stabilityScore.add(decisionMarginalScore);
                succeededEvaluations++;
            }
        } catch (ExecutionException e) {
            LOGGER.error("Saliency stability calculation returned an error {}", e.getMessage());
        } catch (InterruptedException e) {
            LOGGER.error("Interrupted while waiting for saliency stability calculation {}", e.getMessage());
            Thread.currentThread().interrupt();
        } catch (TimeoutException e) {
            LOGGER.error("Timed out while waiting for saliency stability calculation", e);
        }
    }
    if (succeededEvaluations > 0) {
        stabilityScore = stabilityScore.divide(BigDecimal.valueOf(succeededEvaluations), RoundingMode.CEILING);
    }
    return stabilityScore;
}
Also used : LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) Prediction(org.kie.kogito.explainability.model.Prediction) LocalSaliencyStability(org.kie.kogito.explainability.utils.LocalSaliencyStability) ExecutionException(java.util.concurrent.ExecutionException) BigDecimal(java.math.BigDecimal) TimeoutException(java.util.concurrent.TimeoutException)
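
The averaging step at the end of getStabilityScore can be looked at in isolation. The following is a minimal sketch using only java.math, not code from kogito-apps; the class StabilityAverageSketch and the method average are hypothetical names chosen for illustration. It shows that BigDecimal.divide(divisor, RoundingMode) keeps the scale of the dividend, so the precision of the averaged score depends on the scale of the values that were summed.

```java
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;

public class StabilityAverageSketch {

    // Hypothetical helper (not part of kogito-apps): sums the scores and
    // divides by their count, rounding towards positive infinity.
    static BigDecimal average(List<BigDecimal> scores) {
        BigDecimal total = BigDecimal.ZERO;
        for (BigDecimal score : scores) {
            total = total.add(score);
        }
        if (scores.isEmpty()) {
            // Nothing to average: return zero instead of dividing by zero.
            return total;
        }
        // The quotient has the same scale as 'total' (the dividend).
        return total.divide(BigDecimal.valueOf(scores.size()), RoundingMode.CEILING);
    }

    public static void main(String[] args) {
        List<BigDecimal> scores = List.of(
                new BigDecimal("0.50"), new BigDecimal("0.25"), new BigDecimal("0.25"));
        // The sum is 1.00 (scale 2); 1.00 / 3 rounded up at scale 2 prints 0.34.
        System.out.println(average(scores));
    }
}
```

If every summed value had scale 0, the quotient would be rounded to a whole number, which is worth keeping in mind when reading the result of such an average.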

Example 28 with LimeExplainer

use of org.kie.kogito.explainability.local.lime.LimeExplainer in project kogito-apps by kiegroup.

the class LimeExplainerProducerTest method produce.

@Test
void produce() {
    LimeExplainerProducer producer = new LimeExplainerProducer(1, 2, 10);
    LimeExplainer limeExplainer = producer.produce();
    assertNotNull(limeExplainer);
    assertEquals(1, limeExplainer.getLimeConfig().getNoOfSamples());
    assertEquals(2, limeExplainer.getLimeConfig().getPerturbationContext().getNoOfPerturbations());
    assertEquals(LimeConfig.DEFAULT_NO_OF_RETRIES, limeExplainer.getLimeConfig().getNoOfRetries());
}
Also used : LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) Test(org.junit.jupiter.api.Test)

Example 29 with LimeExplainer

use of org.kie.kogito.explainability.local.lime.LimeExplainer in project kogito-apps by kiegroup.

the class LocalExplainerServiceHandlerRegistryTest method setup.

@BeforeEach
@SuppressWarnings("unchecked")
public void setup() {
    LimeExplainer limeExplainer = mock(LimeExplainer.class);
    CounterfactualExplainer counterfactualExplainer = mock(CounterfactualExplainer.class);
    PredictionProviderFactory predictionProviderFactory = mock(PredictionProviderFactory.class);
    limeExplainerServiceHandler = spy(new LimeExplainerServiceHandler(limeExplainer, predictionProviderFactory));
    counterfactualExplainerServiceHandler = spy(new CounterfactualExplainerServiceHandler(counterfactualExplainer, predictionProviderFactory, MAX_RUNNING_TIME_SECONDS));
    predictionProvider = mock(PredictionProvider.class);
    callback = mock(Consumer.class);
    when(predictionProviderFactory.createPredictionProvider(any(), any(), any())).thenReturn(predictionProvider);
    Instance<LocalExplainerServiceHandler<?, ?>> explanationHandlers = mock(Instance.class);
    when(explanationHandlers.stream()).thenReturn(Stream.of(limeExplainerServiceHandler, counterfactualExplainerServiceHandler));
    registry = new LocalExplainerServiceHandlerRegistry(explanationHandlers);
}
Also used : Consumer(java.util.function.Consumer) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) CounterfactualExplainer(org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) PredictionProviderFactory(org.kie.kogito.explainability.PredictionProviderFactory) BeforeEach(org.junit.jupiter.api.BeforeEach)

Example 30 with LimeExplainer

use of org.kie.kogito.explainability.local.lime.LimeExplainer in project kogito-apps by kiegroup.

the class TrafficViolationDmnLimeExplainerTest method testTrafficViolationDMNExplanation.

@Test
void testTrafficViolationDMNExplanation() throws ExecutionException, InterruptedException, TimeoutException {
    PredictionProvider model = getModel();
    PredictionInput predictionInput = getTestInput();
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(predictionInput)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    Prediction prediction = new SimplePrediction(predictionInput, predictionOutputs.get(0));
    Random random = new Random();
    PerturbationContext perturbationContext = new PerturbationContext(0L, random, 1);
    LimeConfig limeConfig = new LimeConfig().withSamples(10).withPerturbationContext(perturbationContext);
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    for (Saliency saliency : saliencyMap.values()) {
        assertNotNull(saliency);
        List<String> strings = saliency.getTopFeatures(3).stream().map(f -> f.getFeature().getName()).collect(Collectors.toList());
        assertTrue(strings.contains("Actual Speed") || strings.contains("Speed Limit"));
    }
    assertDoesNotThrow(() -> ValidationUtils.validateLocalSaliencyStability(model, prediction, limeExplainer, 1, 0.3, 0.3));
    String decision = "Fine";
    List<PredictionInput> inputs = new ArrayList<>();
    for (int n = 0; n < 10; n++) {
        inputs.add(new PredictionInput(DataUtils.perturbFeatures(predictionInput.getFeatures(), perturbationContext)));
    }
    DataDistribution distribution = new PredictionInputsDataDistribution(inputs);
    int k = 2;
    int chunkSize = 5;
    double f1 = ExplainabilityMetrics.getLocalSaliencyF1(decision, model, limeExplainer, distribution, k, chunkSize);
    AssertionsForClassTypes.assertThat(f1).isBetween(0.5d, 1d);
}
Also used : FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) Assertions.assertNotNull(org.junit.jupiter.api.Assertions.assertNotNull) DecisionModel(org.kie.kogito.decision.DecisionModel) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) Assertions.assertThat(org.assertj.core.api.Assertions.assertThat) AssertionsForClassTypes(org.assertj.core.api.AssertionsForClassTypes) TimeoutException(java.util.concurrent.TimeoutException) DmnDecisionModel(org.kie.kogito.dmn.DmnDecisionModel) HashMap(java.util.HashMap) Random(java.util.Random) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) Saliency(org.kie.kogito.explainability.model.Saliency) ArrayList(java.util.ArrayList) Map(java.util.Map) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) DMNRuntime(org.kie.dmn.api.core.DMNRuntime) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) LinkedList(java.util.LinkedList) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) LimeConfigOptimizer(org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) DataUtils(org.kie.kogito.explainability.utils.DataUtils) InputStreamReader(java.io.InputStreamReader) Collectors(java.util.stream.Collectors) DMNKogito(org.kie.kogito.dmn.DMNKogito) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ExecutionException(java.util.concurrent.ExecutionException) Test(org.junit.jupiter.api.Test) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) ExplainabilityMetrics(org.kie.kogito.explainability.utils.ExplainabilityMetrics) Assertions.assertTrue(org.junit.jupiter.api.Assertions.assertTrue) ValidationUtils(org.kie.kogito.explainability.utils.ValidationUtils) Config(org.kie.kogito.explainability.Config) Assertions.assertDoesNotThrow(org.junit.jupiter.api.Assertions.assertDoesNotThrow)
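
The test above waits for asynchronous results with get(timeout, unit), which is why it declares ExecutionException, InterruptedException and TimeoutException. The following is a minimal sketch of that bounded-wait pattern using only java.util.concurrent; it is not code from kogito-apps, and the class BoundedWaitSketch and the method awaitOrFallback are hypothetical names chosen for illustration.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

public class BoundedWaitSketch {

    // Hypothetical helper (not part of kogito-apps): waits up to
    // timeoutMillis for the future and returns 'fallback' on any failure.
    static String awaitOrFallback(CompletableFuture<String> future, long timeoutMillis, String fallback) {
        try {
            return future.get(timeoutMillis, TimeUnit.MILLISECONDS);
        } catch (TimeoutException | ExecutionException e) {
            // The result did not arrive in time, or the computation failed.
            return fallback;
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers can still observe it.
            Thread.currentThread().interrupt();
            return fallback;
        }
    }

    public static void main(String[] args) {
        // An already completed future returns its value immediately.
        System.out.println(awaitOrFallback(CompletableFuture.completedFuture("done"), 100, "fallback"));
        // A future that is never completed times out and yields the fallback.
        System.out.println(awaitOrFallback(new CompletableFuture<>(), 50, "fallback"));
    }
}
```

A test usually lets these exceptions propagate so that a timeout fails the test, as the example above does; returning a fallback is just one way to make the three outcomes visible in a small program.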

Aggregations

LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer) 42
Prediction (org.kie.kogito.explainability.model.Prediction) 38
Test (org.junit.jupiter.api.Test) 37
LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig) 36
PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider) 36
SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction) 35
PredictionInput (org.kie.kogito.explainability.model.PredictionInput) 34
Random (java.util.Random) 33
PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext) 33
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput) 33
LimeConfigOptimizer (org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer) 19
Saliency (org.kie.kogito.explainability.model.Saliency) 16
PredictionInputsDataDistribution (org.kie.kogito.explainability.model.PredictionInputsDataDistribution) 13
DataDistribution (org.kie.kogito.explainability.model.DataDistribution) 12
ArrayList (java.util.ArrayList) 9
Feature (org.kie.kogito.explainability.model.Feature) 7
ExecutionException (java.util.concurrent.ExecutionException) 5
TimeoutException (java.util.concurrent.TimeoutException) 5
FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance) 5
InputStreamReader (java.io.InputStreamReader) 4