
Example 1 with BendableBigDecimalScore

Use of org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore in project kogito-apps by kiegroup.

From the class CounterfactualExplainerTest, method mockExplainerInvocation.

@SuppressWarnings("unchecked")
CounterfactualResult mockExplainerInvocation(Consumer<CounterfactualResult> intermediateResultsConsumer, Long maxRunningTimeSeconds) throws ExecutionException, InterruptedException, TimeoutException {
    // Mock SolverManager and SolverJob to guarantee deterministic test behaviour
    SolverJob<CounterfactualSolution, UUID> solverJob = mock(SolverJob.class);
    CounterfactualSolution solution = mock(CounterfactualSolution.class);
    BendableBigDecimalScore score = BendableBigDecimalScore.zero(0, 0);
    when(solverManager.solveAndListen(any(), any(), any(), any())).thenReturn(solverJob);
    when(solverJob.getFinalBestSolution()).thenReturn(solution);
    when(solution.getScore()).thenReturn(score);
    when(solverManagerFactory.apply(any())).thenReturn(solverManager);
    // Setup Explainer
    final CounterfactualConfig counterfactualConfig = new CounterfactualConfig().withSolverManagerFactory(solverManagerFactory);
    final CounterfactualExplainer counterfactualExplainer = new CounterfactualExplainer(counterfactualConfig);
    // Setup mock model, what it does is not important
    Prediction prediction = new CounterfactualPrediction(new PredictionInput(Collections.emptyList()), new PredictionOutput(Collections.emptyList()), null, UUID.randomUUID(), maxRunningTimeSeconds);
    return counterfactualExplainer.explainAsync(prediction, (List<PredictionInput> inputs) -> CompletableFuture.completedFuture(Collections.emptyList()), intermediateResultsConsumer).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Prediction(org.kie.kogito.explainability.model.Prediction) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) BendableBigDecimalScore(org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore) UUID(java.util.UUID)
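Example 1 builds its mocked score with BendableBigDecimalScore.zero(0, 0), that is, a score with zero hard levels and zero soft levels. To make the structure of a bendable score concrete, here is a minimal stdlib-only sketch. The class BendableScoreSketch is hypothetical: it only mirrors the ideas the examples rely on (one BigDecimal per level, level-wise add, feasibility when no hard level is negative, every hard level outranking every soft level). It is not OptaPlanner's implementation and omits details such as the init score.

```java
import java.math.BigDecimal;
import java.util.Arrays;

/**
 * Hypothetical stdlib-only model of a bendable score (NOT OptaPlanner's class):
 * an array of hard levels and an array of soft levels.
 */
final class BendableScoreSketch implements Comparable<BendableScoreSketch> {
    private final BigDecimal[] hard;
    private final BigDecimal[] soft;

    private BendableScoreSketch(BigDecimal[] hard, BigDecimal[] soft) {
        this.hard = hard;
        this.soft = soft;
    }

    /** A score with the given number of hard and soft levels, all set to zero. */
    static BendableScoreSketch zero(int hardLevels, int softLevels) {
        BigDecimal[] h = new BigDecimal[hardLevels];
        BigDecimal[] s = new BigDecimal[softLevels];
        Arrays.fill(h, BigDecimal.ZERO);
        Arrays.fill(s, BigDecimal.ZERO);
        return new BendableScoreSketch(h, s);
    }

    /** Builds a score from explicit level values (arrays are copied defensively). */
    static BendableScoreSketch of(BigDecimal[] hard, BigDecimal[] soft) {
        return new BendableScoreSketch(hard.clone(), soft.clone());
    }

    int getHardLevelsSize() { return hard.length; }
    int getSoftLevelsSize() { return soft.length; }
    BigDecimal getHardScore(int level) { return hard[level]; }
    BigDecimal getSoftScore(int level) { return soft[level]; }

    /** Level-wise addition; both scores must have the same level sizes. */
    BendableScoreSketch add(BendableScoreSketch other) {
        if (hard.length != other.hard.length || soft.length != other.soft.length) {
            throw new IllegalArgumentException("Level sizes differ");
        }
        BigDecimal[] h = new BigDecimal[hard.length];
        BigDecimal[] s = new BigDecimal[soft.length];
        for (int i = 0; i < h.length; i++) h[i] = hard[i].add(other.hard[i]);
        for (int i = 0; i < s.length; i++) s[i] = soft[i].add(other.soft[i]);
        return new BendableScoreSketch(h, s);
    }

    /** Feasible when no hard level is negative. */
    boolean isFeasible() {
        for (BigDecimal h : hard) {
            if (h.signum() < 0) return false;
        }
        return true;
    }

    /** Lexicographic comparison, hard levels first; assumes equal level sizes. */
    @Override
    public int compareTo(BendableScoreSketch other) {
        for (int i = 0; i < hard.length; i++) {
            int c = hard[i].compareTo(other.hard[i]);
            if (c != 0) return c;
        }
        for (int i = 0; i < soft.length; i++) {
            int c = soft[i].compareTo(other.soft[i]);
            if (c != 0) return c;
        }
        return 0;
    }
}
```

In this model a score with no levels at all, like the one Example 1 mocks, is trivially feasible, which is enough for a test whose model "is not important".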

Example 2 with BendableBigDecimalScore

Use of org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore in project kogito-apps by kiegroup.

From the class CounterfactualExplainerTest, method testSequenceIds.

@ParameterizedTest
@ValueSource(ints = { 1, 2, 3, 5, 8 })
@SuppressWarnings("unchecked")
void testSequenceIds(int numberOfIntermediateSolutions) throws ExecutionException, InterruptedException, TimeoutException {
    final List<Long> sequenceIds = new ArrayList<>();
    final Consumer<CounterfactualResult> captureSequenceIds = counterfactual -> {
        sequenceIds.add(counterfactual.getSequenceId());
    };
    ArgumentCaptor<Consumer<CounterfactualSolution>> intermediateSolutionConsumerCaptor = ArgumentCaptor.forClass(Consumer.class);
    CounterfactualResult result = mockExplainerInvocation(captureSequenceIds, null);
    verify(solverManager).solveAndListen(any(), any(), intermediateSolutionConsumerCaptor.capture(), any());
    Consumer<CounterfactualSolution> intermediateSolutionConsumer = intermediateSolutionConsumerCaptor.getValue();
    // Mock the intermediate Solution callback being invoked
    IntStream.range(0, numberOfIntermediateSolutions).forEach(i -> {
        CounterfactualSolution intermediate = mock(CounterfactualSolution.class);
        BendableBigDecimalScore intermediateScore = BendableBigDecimalScore.zero(0, 0);
        when(intermediate.getScore()).thenReturn(intermediateScore);
        intermediateSolutionConsumer.accept(intermediate);
    });
    // The final and intermediate Solutions should all have unique Sequence Ids.
    sequenceIds.add(result.getSequenceId());
    assertEquals(numberOfIntermediateSolutions + 1, sequenceIds.size());
    assertEquals(numberOfIntermediateSolutions + 1, (int) sequenceIds.stream().distinct().count());
}
Also used : BeforeEach(org.junit.jupiter.api.BeforeEach) FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) Feature(org.kie.kogito.explainability.model.Feature) LoggerFactory(org.slf4j.LoggerFactory) Assertions.assertNotEquals(org.junit.jupiter.api.Assertions.assertNotEquals) TimeoutException(java.util.concurrent.TimeoutException) Random(java.util.Random) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) Value(org.kie.kogito.explainability.model.Value) TerminationConfig(org.optaplanner.core.config.solver.termination.TerminationConfig) Assertions.assertFalse(org.junit.jupiter.api.Assertions.assertFalse) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) EmptyFeatureDomain(org.kie.kogito.explainability.model.domain.EmptyFeatureDomain) Mockito.atLeast(org.mockito.Mockito.atLeast) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) CategoricalFeatureDomain(org.kie.kogito.explainability.model.domain.CategoricalFeatureDomain) DataUtils(org.kie.kogito.explainability.utils.DataUtils) UUID(java.util.UUID) Collectors(java.util.stream.Collectors) Test(org.junit.jupiter.api.Test) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Stream(java.util.stream.Stream) NormalDistribution(org.apache.commons.math3.distribution.NormalDistribution) Output(org.kie.kogito.explainability.model.Output) Assertions.assertTrue(org.junit.jupiter.api.Assertions.assertTrue) SolverJob(org.optaplanner.core.api.solver.SolverJob) Mockito.mock(org.mockito.Mockito.mock) IntStream(java.util.stream.IntStream) ArgumentMatchers.any(org.mockito.ArgumentMatchers.any) SolverConfig(org.optaplanner.core.config.solver.SolverConfig) Assertions.assertNotNull(org.junit.jupiter.api.Assertions.assertNotNull) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) Prediction(org.kie.kogito.explainability.model.Prediction) DataDomain(org.kie.kogito.explainability.model.DataDomain) Assertions.assertNull(org.junit.jupiter.api.Assertions.assertNull) EnvironmentMode(org.optaplanner.core.config.solver.EnvironmentMode) CompletableFuture(java.util.concurrent.CompletableFuture) SolverManager(org.optaplanner.core.api.solver.SolverManager) Function(java.util.function.Function) ArrayList(java.util.ArrayList) MockCounterFactualScoreCalculator(org.kie.kogito.explainability.local.counterfactual.score.MockCounterFactualScoreCalculator) ArgumentCaptor(org.mockito.ArgumentCaptor) NumericFeatureDistribution(org.kie.kogito.explainability.model.NumericFeatureDistribution) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) LinkedList(java.util.LinkedList) BendableBigDecimalScore(org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore) ValueSource(org.junit.jupiter.params.provider.ValueSource) Logger(org.slf4j.Logger) Mockito.when(org.mockito.Mockito.when) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Mockito.verify(org.mockito.Mockito.verify) ExecutionException(java.util.concurrent.ExecutionException) TimeUnit(java.util.concurrent.TimeUnit) Consumer(java.util.function.Consumer) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest) TestUtils(org.kie.kogito.explainability.TestUtils) NumericalFeatureDomain(org.kie.kogito.explainability.model.domain.NumericalFeatureDomain) Config(org.kie.kogito.explainability.Config) Collections(java.util.Collections) FeatureDomain(org.kie.kogito.explainability.model.domain.FeatureDomain)
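testSequenceIds requires that the final result and every intermediate result carry a distinct sequence id. One simple way to guarantee that property is to draw every id, intermediate or final, from a single shared counter. The sketch below assumes an AtomicLong for this; SequenceIdSketch and its Result class are hypothetical and do not claim to reproduce how CounterfactualExplainer assigns ids. allIdsUnique repeats the two checks the test makes: n + 1 ids collected, n + 1 distinct ids.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;

/**
 * Hypothetical sketch: one shared AtomicLong hands out sequence ids to every
 * intermediate result and to the final result, so all ids are distinct.
 */
final class SequenceIdSketch {
    /** Minimal stand-in for a result that carries a sequence id. */
    static final class Result {
        final long sequenceId;
        final String label;

        Result(long sequenceId, String label) {
            this.sequenceId = sequenceId;
            this.label = label;
        }
    }

    private final AtomicLong nextId = new AtomicLong(0);

    /** Every result, intermediate or final, consumes a fresh id atomically. */
    Result newResult(String label) {
        return new Result(nextId.getAndIncrement(), label);
    }

    /** Emits n intermediate results to the consumer, then returns the final one. */
    Result run(int intermediates, Consumer<Result> consumer) {
        for (int i = 0; i < intermediates; i++) {
            consumer.accept(newResult("intermediate-" + i));
        }
        return newResult("final");
    }

    /** Same checks as testSequenceIds: n + 1 ids collected, n + 1 distinct. */
    static boolean allIdsUnique(int intermediates) {
        List<Long> ids = new ArrayList<>();
        Result result = new SequenceIdSketch().run(intermediates, r -> ids.add(r.sequenceId));
        ids.add(result.sequenceId);
        return ids.size() == intermediates + 1
                && ids.stream().distinct().count() == intermediates + 1;
    }
}
```

getAndIncrement keeps ids unique even if the solver reported intermediate solutions from another thread, which is why an atomic counter is a natural choice here.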

Example 3 with BendableBigDecimalScore

Use of org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore in project kogito-apps by kiegroup.

From the class CounterFactualScoreCalculator, method calculateScore.

/**
 * Calculates the counterfactual score for each proposed solution.
 * This method assumes that each model used as {@link org.kie.kogito.explainability.model.PredictionProvider} is
 * consistent, in the sense that for repeated operations, the size of the returned collection of
 * {@link PredictionOutput} is the same, if the size of {@link PredictionInput} doesn't change.
 *
 * @param solution Proposed solution
 * @return A {@link BendableBigDecimalScore} with three "hard" levels and two "soft" levels
 */
@Override
public BendableBigDecimalScore calculateScore(CounterfactualSolution solution) {
    BendableBigDecimalScore currentScore = calculateInputScore(solution);
    final List<Feature> flattenedFeatures = solution.getEntities().stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList());
    final List<Feature> input = CompositeFeatureUtils.unflattenFeatures(flattenedFeatures, solution.getOriginalFeatures());
    final List<PredictionInput> inputs = List.of(new PredictionInput(input));
    final CompletableFuture<List<PredictionOutput>> predictionAsync = solution.getModel().predictAsync(inputs);
    try {
        List<PredictionOutput> predictions = predictionAsync.get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        solution.setPredictionOutputs(predictions);
        final BendableBigDecimalScore outputScore = calculateOutputScore(solution);
        currentScore = currentScore.add(outputScore);
    } catch (ExecutionException e) {
        logger.error("Prediction returned an error {}", e.getMessage());
    } catch (InterruptedException e) {
        logger.error("Interrupted while waiting for prediction {}", e.getMessage());
        Thread.currentThread().interrupt();
    } catch (TimeoutException e) {
        logger.error("Timed out while waiting for prediction");
    }
    return currentScore;
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) BendableBigDecimalScore(org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore) List(java.util.List) ExecutionException(java.util.concurrent.ExecutionException) Feature(org.kie.kogito.explainability.model.Feature) TimeoutException(java.util.concurrent.TimeoutException)
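The control flow of calculateScore starts from the input-only score and adds the output score only when the asynchronous prediction completes within the configured timeout; on an error, an interruption or a timeout it logs and returns the input score unchanged. That pattern can be isolated in a small stdlib-only sketch. TimeoutGuardSketch and its plain double scores are hypothetical simplifications: the real method works with BendableBigDecimalScore, a PredictionProvider and the timeout from Config.INSTANCE, and logs through its own logger.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

/**
 * Hypothetical sketch of the timeout-guarded pattern in calculateScore,
 * with scores simplified to doubles.
 */
final class TimeoutGuardSketch {
    static double score(double inputScore,
                        CompletableFuture<Double> prediction,
                        long timeout, TimeUnit unit) {
        double current = inputScore;
        try {
            // Blocks for at most `timeout`; the output score is added only on success.
            double outputScore = prediction.get(timeout, unit);
            current += outputScore;
        } catch (ExecutionException e) {
            // The model failed: keep the input-only score.
            System.err.println("Prediction returned an error " + e.getMessage());
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers can still observe the interruption.
            Thread.currentThread().interrupt();
        } catch (TimeoutException e) {
            System.err.println("Timed out while waiting for prediction");
        }
        return current;
    }
}
```

Restoring the interrupt flag in the InterruptedException branch, as the original method does, matters because the exception is swallowed here; without it, code further up the stack could not tell that the thread was asked to stop.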

Example 4 with BendableBigDecimalScore

Use of org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore in project kogito-apps by kiegroup.

From the class CounterfactualScoreCalculatorTest, method testGoalSizeMatch.

/**
 * If the goal and the model's output are the same, the distances should all be zero.
 */
@Test
void testGoalSizeMatch() throws ExecutionException, InterruptedException {
    final CounterFactualScoreCalculator scoreCalculator = new CounterFactualScoreCalculator();
    PredictionProvider model = TestUtils.getFeatureSkipModel(0);
    List<Feature> features = new ArrayList<>();
    List<FeatureDomain> featureDomains = new ArrayList<>();
    List<Boolean> constraints = new ArrayList<>();
    // f-1
    features.add(FeatureFactory.newNumericalFeature("f-1", 1.0));
    featureDomains.add(NumericalFeatureDomain.create(0.0, 10.0));
    constraints.add(false);
    // f-2
    features.add(FeatureFactory.newNumericalFeature("f-2", 2.0));
    featureDomains.add(NumericalFeatureDomain.create(0.0, 10.0));
    constraints.add(false);
    // f-3
    features.add(FeatureFactory.newBooleanFeature("f-3", true));
    featureDomains.add(EmptyFeatureDomain.create());
    constraints.add(false);
    PredictionInput input = new PredictionInput(features);
    PredictionFeatureDomain domains = new PredictionFeatureDomain(featureDomains);
    List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(input);
    List<Output> goal = new ArrayList<>();
    goal.add(new Output("f-2", Type.NUMBER, new Value(2.0), 0.0));
    goal.add(new Output("f-3", Type.BOOLEAN, new Value(true), 0.0));
    final CounterfactualSolution solution = new CounterfactualSolution(entities, features, model, goal, UUID.randomUUID(), UUID.randomUUID(), 0.0);
    BendableBigDecimalScore score = scoreCalculator.calculateScore(solution);
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input)).get();
    assertTrue(score.isFeasible());
    assertEquals(2, goal.size());
    // A single prediction is expected
    assertEquals(1, predictionOutputs.size());
    // The single prediction has two outputs
    assertEquals(2, predictionOutputs.get(0).getOutputs().size());
    assertEquals(0, score.getHardScore(0).compareTo(BigDecimal.ZERO));
    assertEquals(0, score.getHardScore(1).compareTo(BigDecimal.ZERO));
    assertEquals(0, score.getHardScore(2).compareTo(BigDecimal.ZERO));
    assertEquals(0, score.getSoftScore(0).compareTo(BigDecimal.ZERO));
    assertEquals(0, score.getSoftScore(1).compareTo(BigDecimal.ZERO));
    assertEquals(3, score.getHardLevelsSize());
    assertEquals(2, score.getSoftLevelsSize());
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) BendableBigDecimalScore(org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore) EmptyFeatureDomain(org.kie.kogito.explainability.model.domain.EmptyFeatureDomain) PredictionFeatureDomain(org.kie.kogito.explainability.model.PredictionFeatureDomain) NumericalFeatureDomain(org.kie.kogito.explainability.model.domain.NumericalFeatureDomain) FeatureDomain(org.kie.kogito.explainability.model.domain.FeatureDomain) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) Value(org.kie.kogito.explainability.model.Value) Test(org.junit.jupiter.api.Test) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)
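The assertions in testGoalSizeMatch check each score level with compareTo(BigDecimal.ZERO) == 0 rather than equals. The source does not say why, but a likely reason is that BigDecimal.equals is scale-sensitive: 0.0 (scale 1) is not equal to BigDecimal.ZERO (scale 0), while compareTo compares numeric value only. The small demo below shows the difference; BigDecimalZeroCheck and isZero are illustrative names, not part of kogito-apps.

```java
import java.math.BigDecimal;

/** Shows why zero checks on BigDecimal score levels use compareTo instead of equals. */
final class BigDecimalZeroCheck {
    /** Numeric zero test that ignores scale (0, 0.0 and 0.000 all count). */
    static boolean isZero(BigDecimal value) {
        return value.compareTo(BigDecimal.ZERO) == 0;
    }

    public static void main(String[] args) {
        // A zero produced by arithmetic on scaled values keeps the scale: 0.0
        BigDecimal computedZero = new BigDecimal("1.0").subtract(new BigDecimal("1.0"));
        // equals compares value AND scale, so this prints false
        System.out.println(computedZero.equals(BigDecimal.ZERO));
        // compareTo compares numeric value only, so this prints true
        System.out.println(isZero(computedZero));
    }
}
```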

Example 5 with BendableBigDecimalScore

Use of org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore in project kogito-apps by kiegroup.

From the class CounterfactualScoreCalculatorTest, method testNullBooleanInput.

/**
 * Null values for input Boolean features should be accepted as valid.
 */
@Test
void testNullBooleanInput() throws ExecutionException, InterruptedException {
    final CounterFactualScoreCalculator scoreCalculator = new CounterFactualScoreCalculator();
    PredictionProvider model = TestUtils.getFeatureSkipModel(0);
    List<Feature> features = new ArrayList<>();
    List<FeatureDomain> featureDomains = new ArrayList<>();
    List<Boolean> constraints = new ArrayList<>();
    // f-1
    features.add(FeatureFactory.newNumericalFeature("f-1", 1.0));
    featureDomains.add(NumericalFeatureDomain.create(0.0, 10.0));
    constraints.add(false);
    // f-2
    features.add(FeatureFactory.newBooleanFeature("f-2", null));
    featureDomains.add(EmptyFeatureDomain.create());
    constraints.add(false);
    // f-3
    features.add(FeatureFactory.newBooleanFeature("f-3", true));
    featureDomains.add(EmptyFeatureDomain.create());
    constraints.add(false);
    PredictionInput input = new PredictionInput(features);
    PredictionFeatureDomain domains = new PredictionFeatureDomain(featureDomains);
    List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(input);
    List<Output> goal = new ArrayList<>();
    goal.add(new Output("f-2", Type.BOOLEAN, new Value(null), 0.0));
    goal.add(new Output("f-3", Type.BOOLEAN, new Value(true), 0.0));
    final CounterfactualSolution solution = new CounterfactualSolution(entities, features, model, goal, UUID.randomUUID(), UUID.randomUUID(), 0.0);
    BendableBigDecimalScore score = scoreCalculator.calculateScore(solution);
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input)).get();
    assertTrue(score.isFeasible());
    assertEquals(2, goal.size());
    // A single prediction is expected
    assertEquals(1, predictionOutputs.size());
    // The single prediction has two outputs
    assertEquals(2, predictionOutputs.get(0).getOutputs().size());
    assertEquals(0, score.getHardScore(0).compareTo(BigDecimal.ZERO));
    assertEquals(0, score.getHardScore(1).compareTo(BigDecimal.ZERO));
    assertEquals(0, score.getHardScore(2).compareTo(BigDecimal.ZERO));
    assertEquals(0, score.getSoftScore(0).compareTo(BigDecimal.ZERO));
    assertEquals(0, score.getSoftScore(1).compareTo(BigDecimal.ZERO));
    assertEquals(3, score.getHardLevelsSize());
    assertEquals(2, score.getSoftLevelsSize());
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) BendableBigDecimalScore(org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore) EmptyFeatureDomain(org.kie.kogito.explainability.model.domain.EmptyFeatureDomain) PredictionFeatureDomain(org.kie.kogito.explainability.model.PredictionFeatureDomain) NumericalFeatureDomain(org.kie.kogito.explainability.model.domain.NumericalFeatureDomain) FeatureDomain(org.kie.kogito.explainability.model.domain.FeatureDomain) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) Value(org.kie.kogito.explainability.model.Value) Test(org.junit.jupiter.api.Test) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)
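testNullBooleanInput expects a null Boolean goal matched by a null Boolean model output to give zero distance on every level. A minimal way to express that requirement, assuming a simple 0/1 mismatch distance for Boolean features (hypothetical; the source does not show how CounterFactualScoreCalculator measures Boolean distance), is java.util.Objects.equals, which treats two nulls as equal and never throws on a null argument. BooleanDistanceSketch is an illustrative name only.

```java
import java.util.Objects;

/**
 * Hypothetical null-safe distance between two Boolean values:
 * 0.0 when they match (including both null), 1.0 otherwise.
 */
final class BooleanDistanceSketch {
    static double distance(Boolean goal, Boolean actual) {
        // Objects.equals(null, null) is true and never throws NullPointerException,
        // unlike goal.equals(actual) when goal is null.
        return Objects.equals(goal, actual) ? 0.0 : 1.0;
    }
}
```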

Aggregations

PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 6
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 6
BendableBigDecimalScore (org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore): 6
Feature (org.kie.kogito.explainability.model.Feature): 5
ArrayList (java.util.ArrayList): 4
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 4
CounterfactualEntity (org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity): 4
Output (org.kie.kogito.explainability.model.Output): 4
PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 4
EmptyFeatureDomain (org.kie.kogito.explainability.model.domain.EmptyFeatureDomain): 4
FeatureDomain (org.kie.kogito.explainability.model.domain.FeatureDomain): 4
NumericalFeatureDomain (org.kie.kogito.explainability.model.domain.NumericalFeatureDomain): 4
Test (org.junit.jupiter.api.Test): 3
Value (org.kie.kogito.explainability.model.Value): 3
List (java.util.List): 2
Random (java.util.Random): 2
UUID (java.util.UUID): 2
ExecutionException (java.util.concurrent.ExecutionException): 2
TimeoutException (java.util.concurrent.TimeoutException): 2
ValueSource (org.junit.jupiter.params.provider.ValueSource): 2