
Example 16 with CounterfactualPrediction

Use of org.kie.kogito.explainability.model.CounterfactualPrediction in project kogito-apps by kiegroup.

From class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatSearchDomainsNotFixed.

@Test
public void testGetPredictionWithFlatSearchDomainsNotFixed() {
    // A single unit value "output1" (25) with a NON-fixed search domain over [10, 20]
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID,
            List.of(new NamedTypedValue("output1", new UnitValue("number", new IntNode(25)))),
            Collections.emptyList(),
            List.of(new CounterfactualSearchDomain("output1",
                    new CounterfactualSearchDomainUnitValue("number", "number", false,
                            new CounterfactualDomainRange(new IntNode(10), new IntNode(20))))),
            MAX_RUNNING_TIME_SECONDS);
    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;
    assertEquals(1, counterfactualPrediction.getInput().getFeatures().size());
    // A non-fixed search domain becomes a NumericalFeatureDomain carrying the requested bounds
    Feature feature1 = counterfactualPrediction.getInput().getFeatures().get(0);
    assertTrue(feature1.getDomain() instanceof NumericalFeatureDomain);
    final NumericalFeatureDomain domain = (NumericalFeatureDomain) feature1.getDomain();
    assertEquals(10, domain.getLowerBound());
    assertEquals(20, domain.getUpperBound());
    // JUnit convention: expected value first, actual value second
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
Also used: CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) IntNode(com.fasterxml.jackson.databind.node.IntNode) CounterfactualDomainRange(org.kie.kogito.explainability.api.CounterfactualDomainRange) Prediction(org.kie.kogito.explainability.model.Prediction) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) NumericalFeatureDomain(org.kie.kogito.explainability.model.domain.NumericalFeatureDomain) Feature(org.kie.kogito.explainability.model.Feature) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) Test(org.junit.jupiter.api.Test)

Example 17 with CounterfactualPrediction

Use of org.kie.kogito.explainability.model.CounterfactualPrediction in project kogito-apps by kiegroup.

From class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatSearchDomainsFixed.

@Test
public void testGetPredictionWithFlatSearchDomainsFixed() {
    // Same request as the non-fixed case, except that the search domain for "output1" is fixed
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID,
            List.of(new NamedTypedValue("output1", new UnitValue("number", new IntNode(25)))),
            Collections.emptyList(),
            List.of(new CounterfactualSearchDomain("output1",
                    new CounterfactualSearchDomainUnitValue("number", "number", true,
                            new CounterfactualDomainRange(new IntNode(10), new IntNode(20))))),
            MAX_RUNNING_TIME_SECONDS);
    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;
    assertEquals(1, counterfactualPrediction.getInput().getFeatures().size());
    // A fixed search domain yields an EmptyFeatureDomain: there is no range to search
    Feature feature1 = counterfactualPrediction.getInput().getFeatures().get(0);
    assertTrue(feature1.getDomain() instanceof EmptyFeatureDomain);
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
Also used: CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) IntNode(com.fasterxml.jackson.databind.node.IntNode) CounterfactualDomainRange(org.kie.kogito.explainability.api.CounterfactualDomainRange) Prediction(org.kie.kogito.explainability.model.Prediction) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) EmptyFeatureDomain(org.kie.kogito.explainability.model.domain.EmptyFeatureDomain) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) Feature(org.kie.kogito.explainability.model.Feature) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) Test(org.junit.jupiter.api.Test)
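Examples 16 and 17 differ only in the fixed flag passed to CounterfactualSearchDomainUnitValue: when it is false the handler produces a NumericalFeatureDomain bounded by the requested range, and when it is true it produces an EmptyFeatureDomain. The following is a minimal sketch of just that mapping in plain Java (16 or later, for records). Every type in it (DomainMappingSketch, SearchDomain, FeatureDomain, Empty, Numerical, Feature) is a hypothetical stand-in, not a kogito-apps class, and the real handler very likely does more, such as handling non-numeric or composite values.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;

// Hypothetical stand-ins, NOT the kogito-apps API: a sketch of the fixed / not-fixed
// mapping that the handler tests assert.
final class DomainMappingSketch {

    /** Search domain for one named value: either fixed, or bounded by [lower, upper]. */
    record SearchDomain(String name, boolean fixed, double lower, double upper) {}

    /** Domain within which a feature may be varied during the counterfactual search. */
    interface FeatureDomain {}

    /** Fixed feature: nothing to search (plays the role of EmptyFeatureDomain). */
    record Empty() implements FeatureDomain {}

    /** Variable feature bounded by [lowerBound, upperBound] (role of NumericalFeatureDomain). */
    record Numerical(double lowerBound, double upperBound) implements FeatureDomain {}

    /** A named feature with its starting value and its search domain. */
    record Feature(String name, double value, FeatureDomain domain) {}

    /** Builds one feature per search domain; starting values are looked up in originalValues. */
    static List<Feature> toFeatures(List<SearchDomain> domains, Map<String, Double> originalValues) {
        List<Feature> features = new ArrayList<>();
        for (SearchDomain d : domains) {
            // Fixed domains carry no range; non-fixed ones keep the requested bounds.
            FeatureDomain domain = d.fixed() ? new Empty() : new Numerical(d.lower(), d.upper());
            double value = Objects.requireNonNull(originalValues.get(d.name()),
                    "no original value for " + d.name());
            features.add(new Feature(d.name(), value, domain));
        }
        return features;
    }
}
```

For instance, toFeatures(List.of(new SearchDomain("output1", false, 10, 20)), Map.of("output1", 25.0)) returns one feature whose domain is Numerical[lowerBound=10.0, upperBound=20.0], mirroring the assertions in Example 16; setting fixed to true yields an Empty domain as in Example 17, and an empty list of domains yields no features, as in Example 18.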

Example 18 with CounterfactualPrediction

Use of org.kie.kogito.explainability.model.CounterfactualPrediction in project kogito-apps by kiegroup.

From class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithEmptyDefinition.

@Test
public void testGetPredictionWithEmptyDefinition() {
    // All three lists in the request are empty
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID,
            Collections.emptyList(), Collections.emptyList(), Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);
    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;
    // An empty definition yields a prediction with no input features
    assertTrue(counterfactualPrediction.getInput().getFeatures().isEmpty());
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
Also used: CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) Prediction(org.kie.kogito.explainability.model.Prediction) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) Test(org.junit.jupiter.api.Test)

Example 19 with CounterfactualPrediction

Use of org.kie.kogito.explainability.model.CounterfactualPrediction in project kogito-apps by kiegroup.

From class PrequalificationDmnCounterfactualExplainerTest, method testValidCounterfactual.

@Test
void testValidCounterfactual() throws ExecutionException, InterruptedException, TimeoutException {
    PredictionProvider model = getModel();
    final List<Output> goal = List.of(new Output("Qualified?", Type.BOOLEAN, new Value(true), 0.0d));
    final TerminationConfig terminationConfig = new TerminationConfig().withScoreCalculationCountLimit(steps);
    final SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(terminationConfig).build();
    solverConfig.setRandomSeed(randomSeed);
    solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
    CounterfactualConfig config = new CounterfactualConfig().withGoalThreshold(0.1);
    config.withSolverConfig(solverConfig);
    final CounterfactualExplainer explainer = new CounterfactualExplainer(config);
    PredictionInput input = getTestInputVariable();
    PredictionOutput output = new PredictionOutput(goal);
    // Sanity check: with the fixed test input, the model returns Qualified? = false
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(getTestInputFixed()))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    final Output predictionOutput = predictionOutputs.get(0).getOutputs().get(0);
    assertEquals("Qualified?", predictionOutput.getName());
    assertFalse((Boolean) predictionOutput.getValue().getUnderlyingObject());
    Prediction prediction = new CounterfactualPrediction(input, output, null, UUID.randomUUID(), null);
    CounterfactualResult counterfactualResult = explainer.explainAsync(prediction, model).get();
    List<Feature> cfFeatures = counterfactualResult.getEntities().stream()
            .map(CounterfactualEntity::asFeature)
            .collect(Collectors.toList());
    // Rebuild composite features from the flattened counterfactual entities, then re-run the model on them
    List<Feature> unflattened = CompositeFeatureUtils.unflattenFeatures(cfFeatures, input.getFeatures());
    List<PredictionOutput> outputs = model.predictAsync(List.of(new PredictionInput(unflattened)))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    assertTrue(counterfactualResult.isValid());
    final Output decideOutput = outputs.get(0).getOutputs().get(0);
    assertEquals("Qualified?", decideOutput.getName());
    assertTrue((Boolean) decideOutput.getValue().getUnderlyingObject());
}
Also used: PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) CounterfactualResult(org.kie.kogito.explainability.local.counterfactual.CounterfactualResult) TerminationConfig(org.optaplanner.core.config.solver.termination.TerminationConfig) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) CounterfactualConfig(org.kie.kogito.explainability.local.counterfactual.CounterfactualConfig) Value(org.kie.kogito.explainability.model.Value) CounterfactualExplainer(org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer) SolverConfig(org.optaplanner.core.config.solver.SolverConfig) Test(org.junit.jupiter.api.Test)
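Example 19 checks the core counterfactual property: starting from an input the model does not qualify (Qualified? is false), the explainer must find a modified input for which the model returns the goal output (Qualified? is true), and re-running the model on that input confirms it. The CounterfactualExplainer configured there delegates the search to an OptaPlanner solver; the sketch below swaps that for a plain exhaustive grid search over a single numeric feature, only to make the idea concrete. GridCounterfactualSketch, its Result record, and the toy model in the usage note are all hypothetical.

```java
import java.util.Optional;
import java.util.function.DoublePredicate;

// Hypothetical sketch: a brute-force counterfactual search over ONE numeric feature.
// This is NOT CounterfactualExplainer's algorithm (which uses an OptaPlanner solver).
final class GridCounterfactualSketch {

    /** The counterfactual value found, plus whether re-running the model confirms the goal. */
    record Result(double value, boolean valid) {}

    /**
     * Scans [lower, upper] in `steps` equal increments and returns the candidate closest to
     * `original` for which the model produces `goal`, or empty if no candidate does.
     */
    static Optional<Result> search(DoublePredicate model, boolean goal,
                                   double original, double lower, double upper, int steps) {
        if (steps <= 0) {
            throw new IllegalArgumentException("steps must be positive");
        }
        Double best = null;
        for (int i = 0; i <= steps; i++) {
            double candidate = lower + (upper - lower) * i / steps;
            boolean reachesGoal = model.test(candidate) == goal;
            // Prefer the goal-reaching candidate that changes the original value the least.
            if (reachesGoal && (best == null || Math.abs(candidate - original) < Math.abs(best - original))) {
                best = candidate;
            }
        }
        if (best == null) {
            return Optional.empty();
        }
        // Re-predict on the counterfactual, mirroring the test's final model.predictAsync(...) call.
        // For a deterministic predicate this check always passes; it is kept to show the step.
        boolean valid = model.test(best) == goal;
        return Optional.of(new Result(best, valid));
    }
}
```

With the hypothetical rule income -> income >= 3000.0, an original value of 2000.0, and the range [0, 10000] scanned in 100 steps, search(...) returns the grid point 3000.0, the closest candidate that flips the output to true. Narrowing the range to [0, 2500] returns an empty Optional, because no candidate reaches the goal. A grid search is easy to follow but scales exponentially with the number of variable features, which is one reason a solver-based local search is used instead.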

Aggregations

CounterfactualPrediction (org.kie.kogito.explainability.model.CounterfactualPrediction) 19
Prediction (org.kie.kogito.explainability.model.Prediction) 18
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput) 16
Feature (org.kie.kogito.explainability.model.Feature) 14
Test (org.junit.jupiter.api.Test) 13
Output (org.kie.kogito.explainability.model.Output) 13
PredictionInput (org.kie.kogito.explainability.model.PredictionInput) 13
PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider) 12
SolverConfig (org.optaplanner.core.config.solver.SolverConfig) 11
TerminationConfig (org.optaplanner.core.config.solver.termination.TerminationConfig) 10
CounterfactualEntity (org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) 9
Value (org.kie.kogito.explainability.model.Value) 8
LinkedList (java.util.LinkedList) 7
CounterfactualExplainer (org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer) 7
CounterfactualResult (org.kie.kogito.explainability.local.counterfactual.CounterfactualResult) 7
UUID (java.util.UUID) 6
List (java.util.List) 5
Consumer (java.util.function.Consumer) 5
Collectors (java.util.stream.Collectors) 5
CounterfactualExplainabilityRequest (org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) 5