Search in sources:

Example 6 with CounterfactualExplainer

Use of org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer in project kogito-apps by kiegroup.

Source: method `setup` of class `LocalExplainerServiceHandlerRegistryTest`.

@BeforeEach
@SuppressWarnings("unchecked")
public void setup() {
    // The explainers and the provider factory are plain mocks: the handlers only need collaborators to delegate to.
    LimeExplainer limeMock = mock(LimeExplainer.class);
    CounterfactualExplainer counterfactualMock = mock(CounterfactualExplainer.class);
    PredictionProviderFactory providerFactory = mock(PredictionProviderFactory.class);

    // Spies wrap real handler instances, so the real methods run while interactions stay verifiable.
    limeExplainerServiceHandler = spy(new LimeExplainerServiceHandler(limeMock, providerFactory));
    counterfactualExplainerServiceHandler = spy(
            new CounterfactualExplainerServiceHandler(counterfactualMock, providerFactory, MAX_RUNNING_TIME_SECONDS));

    predictionProvider = mock(PredictionProvider.class);
    callback = mock(Consumer.class);
    // Whatever arguments are passed, the factory hands back the shared mocked provider.
    when(providerFactory.createPredictionProvider(any(), any(), any())).thenReturn(predictionProvider);

    // The CDI Instance is mocked to expose exactly the two handlers built above.
    // NOTE(review): one Stream object is returned, and a Stream can be consumed only once.
    Stream<LocalExplainerServiceHandler<?, ?>> availableHandlers =
            Stream.of(limeExplainerServiceHandler, counterfactualExplainerServiceHandler);
    Instance<LocalExplainerServiceHandler<?, ?>> handlerInstance = mock(Instance.class);
    when(handlerInstance.stream()).thenReturn(availableHandlers);
    registry = new LocalExplainerServiceHandlerRegistry(handlerInstance);
}
Also used: Consumer(java.util.function.Consumer) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) CounterfactualExplainer(org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) PredictionProviderFactory(org.kie.kogito.explainability.PredictionProviderFactory) BeforeEach(org.junit.jupiter.api.BeforeEach)

Example 7 with CounterfactualExplainer

Use of org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer in project kogito-apps by kiegroup.

Source: method `produce` of class `CounterfactualExplainerProducer`.

/**
 * Produces the {@link CounterfactualExplainer} bean.
 *
 * <p>The explainer is configured from this producer's {@code goalThreshold} and {@code executor} fields.
 *
 * @return a newly constructed counterfactual explainer
 */
@Produces
public CounterfactualExplainer produce() {
    final CounterfactualConfig counterfactualConfig = new CounterfactualConfig()
            .withGoalThreshold(this.goalThreshold)
            .withExecutor(executor);
    final CounterfactualExplainer explainer = new CounterfactualExplainer(counterfactualConfig);
    // Log only after construction succeeded; logging first would report "created" even when the constructor throws.
    LOG.debug("CounterfactualExplainer created");
    return explainer;
}
Also used: CounterfactualConfig(org.kie.kogito.explainability.local.counterfactual.CounterfactualConfig) CounterfactualExplainer(org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer) Produces(javax.enterprise.inject.Produces)

Example 8 with CounterfactualExplainer

Use of org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer in project kogito-apps by kiegroup.

Source: method `testValidCounterfactual` of class `PrequalificationDmnCounterfactualExplainerTest`.

/**
 * Checks that the counterfactual explainer finds a valid counterfactual for the prequalification model.
 *
 * <p>The fixed test input is first confirmed to produce {@code "Qualified?" = false}. The explainer then
 * searches for inputs that flip the outcome to {@code true}. The counterfactual features are fed back into
 * the model to confirm that it really answers {@code true}.
 *
 * <p>Every asynchronous call is bounded by the configured async timeout, so a stuck solver or model fails
 * the test with a {@link TimeoutException} instead of hanging it.
 */
@Test
void testValidCounterfactual() throws ExecutionException, InterruptedException, TimeoutException {
    PredictionProvider model = getModel();
    // Desired outcome: the model answers "Qualified?" = true.
    final List<Output> goal = List.of(new Output("Qualified?", Type.BOOLEAN, new Value(true), 0.0d));
    // A fixed step budget, a fixed seed and REPRODUCIBLE mode make the solver run repeatable.
    final TerminationConfig terminationConfig = new TerminationConfig().withScoreCalculationCountLimit(steps);
    final SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(terminationConfig).build();
    solverConfig.setRandomSeed(randomSeed);
    solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
    CounterfactualConfig config = new CounterfactualConfig().withGoalThreshold(0.1);
    config.withSolverConfig(solverConfig);
    final CounterfactualExplainer explainer = new CounterfactualExplainer(config);
    PredictionInput input = getTestInputVariable();
    PredictionOutput output = new PredictionOutput(goal);

    // Sanity-check the model: the fixed input must start out as not qualified.
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(getTestInputFixed()))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    final Output predictionOutput = predictionOutputs.get(0).getOutputs().get(0);
    assertEquals("Qualified?", predictionOutput.getName());
    assertFalse((Boolean) predictionOutput.getValue().getUnderlyingObject());

    Prediction prediction = new CounterfactualPrediction(input, output, null, UUID.randomUUID(), null);
    CounterfactualResult counterfactualResult = explainer.explainAsync(prediction, model)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    // Check validity before using the result, so an invalid search fails here with a clear cause.
    assertTrue(counterfactualResult.isValid());

    // Rebuild the composite input structure from the flattened counterfactual entities and re-run the model.
    List<Feature> cfFeatures = counterfactualResult.getEntities().stream()
            .map(CounterfactualEntity::asFeature)
            .collect(Collectors.toList());
    List<Feature> unflattened = CompositeFeatureUtils.unflattenFeatures(cfFeatures, input.getFeatures());
    List<PredictionOutput> outputs = model.predictAsync(List.of(new PredictionInput(unflattened)))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    final Output decideOutput = outputs.get(0).getOutputs().get(0);
    assertEquals("Qualified?", decideOutput.getName());
    assertTrue((Boolean) decideOutput.getValue().getUnderlyingObject());
}
Also used: PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) CounterfactualResult(org.kie.kogito.explainability.local.counterfactual.CounterfactualResult) TerminationConfig(org.optaplanner.core.config.solver.termination.TerminationConfig) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) CounterfactualConfig(org.kie.kogito.explainability.local.counterfactual.CounterfactualConfig) Value(org.kie.kogito.explainability.model.Value) CounterfactualExplainer(org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer) SolverConfig(org.optaplanner.core.config.solver.SolverConfig) Test(org.junit.jupiter.api.Test)

Aggregations

CounterfactualExplainer (org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer): 8; Test (org.junit.jupiter.api.Test): 6; CounterfactualConfig (org.kie.kogito.explainability.local.counterfactual.CounterfactualConfig): 6; PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 6; CounterfactualResult (org.kie.kogito.explainability.local.counterfactual.CounterfactualResult): 5; CounterfactualPrediction (org.kie.kogito.explainability.model.CounterfactualPrediction): 5; Feature (org.kie.kogito.explainability.model.Feature): 5; Output (org.kie.kogito.explainability.model.Output): 5; Prediction (org.kie.kogito.explainability.model.Prediction): 5; PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 5; PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 5; SolverConfig (org.optaplanner.core.config.solver.SolverConfig): 5; TerminationConfig (org.optaplanner.core.config.solver.termination.TerminationConfig): 5; LinkedList (java.util.LinkedList): 3; CounterfactualEntity (org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity): 2; Value (org.kie.kogito.explainability.model.Value): 2; SmallRyeManagedExecutor (io.smallrye.context.SmallRyeManagedExecutor): 1; Consumer (java.util.function.Consumer): 1; Produces (javax.enterprise.inject.Produces): 1; ManagedExecutor (org.eclipse.microprofile.context.ManagedExecutor): 1