Use of org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity in project kogito-apps by kiegroup.
Class CounterfactualExplainerTest, method testCounterfactualCategoricalNotStrict.
/**
 * Search for a counterfactual using categorical features with the Symbolic arithmetic model.
 * The outcome match is not strict (goal threshold of 0.01).
 * The CF should be valid with this number of iterations.
 * <p>
 * The proposed numerical features are recombined with the proposed operand and the result
 * must lie within {@code epsilon} of the goal value (25.0).
 *
 * @param seed seed forwarded to the counterfactual search, so each parameterized run is reproducible
 * @throws ExecutionException propagated from the counterfactual search
 * @throws InterruptedException propagated from the counterfactual search
 * @throws TimeoutException propagated from the counterfactual search
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2 })
void testCounterfactualCategoricalNotStrict(int seed) throws ExecutionException, InterruptedException, TimeoutException {
    final List<Output> goal = List.of(new Output("result", Type.NUMBER, new Value(25.0), 0.0d));
    List<Feature> features = new LinkedList<>();
    features.add(FeatureFactory.newNumericalFeature("x-1", 5.0, NumericalFeatureDomain.create(0.0, 100.0)));
    features.add(FeatureFactory.newNumericalFeature("x-2", 40.0, NumericalFeatureDomain.create(0.0, 100.0)));
    features.add(FeatureFactory.newCategoricalFeature("operand", "*", CategoricalFeatureDomain.create("+", "-", "/", "*")));
    final CounterfactualResult result = runCounterfactualSearch((long) seed, goal, features, TestUtils.getSymbolicArithmeticModel(), 0.01);

    // Convert the proposed entities to features once and reuse the list for both lookups.
    final List<Feature> counterfactualFeatures = result.getEntities().stream()
            .map(CounterfactualEntity::asFeature)
            .collect(Collectors.toList());
    final String operand = counterfactualFeatures.stream()
            .filter(feature -> feature.getName().equals("operand"))
            .findFirst()
            .orElseThrow(() -> new IllegalStateException("Counterfactual has no 'operand' feature"))
            .getValue()
            .asString();
    final List<Feature> numericalFeatures = counterfactualFeatures.stream()
            .filter(feature -> !feature.getName().equals("operand"))
            .collect(Collectors.toList());

    // NOTE(review): the accumulator starts at 0.0, so "*" and "/" always yield 0.0 and "-" yields
    // -(x-1 + x-2); presumably this mirrors TestUtils.getSymbolicArithmeticModel — confirm against it.
    double opResult = 0.0;
    for (Feature feature : numericalFeatures) {
        final double value = feature.getValue().asNumber();
        switch (operand) {
            case "+":
                opResult += value;
                break;
            case "-":
                opResult -= value;
                break;
            case "*":
                opResult *= value;
                break;
            case "/":
                opResult /= value;
                break;
            default:
                // Fail loudly instead of silently leaving opResult untouched
                throw new IllegalStateException("Unexpected operand: " + operand);
        }
    }
    final double epsilon = 0.5;
    assertTrue(result.isValid());
    assertTrue(opResult <= 25.0 + epsilon);
    assertTrue(opResult >= 25.0 - epsilon);
}
Use of org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity in project kogito-apps by kiegroup.
Class CounterfactualScoreCalculatorTest, method testGoalSizeSmaller.
/**
 * Using a smaller number of features in the goals (1) than the model's output (2) should
 * throw an {@link IllegalArgumentException} with the appropriate message.
 *
 * @throws ExecutionException if the model prediction used to check the output size fails
 * @throws InterruptedException if interrupted while waiting for that prediction
 */
@Test
void testGoalSizeSmaller() throws ExecutionException, InterruptedException {
    final CounterFactualScoreCalculator scoreCalculator = new CounterFactualScoreCalculator();
    // NOTE(review): presumably a model that skips the feature at index 0; the assertions below pin
    // that it returns a single prediction with two outputs for these three inputs.
    final PredictionProvider model = TestUtils.getFeatureSkipModel(0);
    final List<Feature> features = new ArrayList<>();
    features.add(FeatureFactory.newNumericalFeature("f-1", 1.0));
    features.add(FeatureFactory.newNumericalFeature("f-2", 2.0));
    features.add(FeatureFactory.newBooleanFeature("f-3", true));
    final PredictionInput input = new PredictionInput(features);
    final List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(input);
    // Goal with a single output, i.e. one fewer than the model produces
    final List<Output> goal = new ArrayList<>();
    goal.add(new Output("f-2", Type.NUMBER, new Value(2.0), 0.0));
    final List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input)).get();
    assertEquals(1, goal.size());
    // A single prediction is expected
    assertEquals(1, predictionOutputs.size());
    // Single prediction with two outputs
    assertEquals(2, predictionOutputs.get(0).getOutputs().size());
    final CounterfactualSolution solution = new CounterfactualSolution(entities, features, model, goal, UUID.randomUUID(), UUID.randomUUID(), 0.0);
    final IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> scoreCalculator.calculateScore(solution));
    assertEquals("Prediction size must be equal to goal size", exception.getMessage());
}
Use of org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity in project kogito-apps by kiegroup.
Class CounterFactualScoreCalculator, method calculateInputScore.
/**
 * Scores how far the proposed inputs of a solution are from the original inputs.
 * <p>
 * Hard levels: {@code [0, -(number of changed constrained entities), 0]}.
 * Soft levels: {@code [-sqrt(|1 - mean entity similarity|), -(number of changed entities)]}.
 *
 * @param solution the candidate counterfactual solution
 * @return the input-distance component of the solution's score
 */
private BendableBigDecimalScore calculateInputScore(CounterfactualSolution solution) {
    final StringBuilder builder = new StringBuilder();
    int secondarySoftScore = 0;
    int secondaryHardScore = 0;
    // Mean similarity between original and proposed inputs (each entity weighted equally).
    // With no entities the loop never runs, so there is no division by zero.
    double inputSimilarities = 0.0;
    final int numberOfEntities = solution.getEntities().size();
    for (CounterfactualEntity entity : solution.getEntities()) {
        final double entitySimilarity = entity.similarity();
        inputSimilarities += entitySimilarity / numberOfEntities;
        final Feature f = entity.asFeature();
        // Separate entries so the debug log stays readable
        if (builder.length() > 0) {
            builder.append(", ");
        }
        builder.append(String.format("%s=%s (d:%f)", f.getName(), f.getValue().getUnderlyingObject(), entitySimilarity));
        if (entity.isChanged()) {
            // Every changed entity costs one soft point...
            secondarySoftScore -= 1;
            if (entity.isConstrained()) {
                // ...and changing a constrained entity also costs one hard point
                secondaryHardScore -= 1;
            }
        }
    }
    logger.debug("Current solution: {}", builder);
    // Gower-style distance (1 - mean similarity), square-rooted and negated so that closer inputs
    // score higher; the value is always <= 0, so no further abs/negation is needed.
    final double featureDistance = -Math.sqrt(Math.abs(1.0 - inputSimilarities));
    logger.debug("Changed constraints penalty: {}", secondaryHardScore);
    logger.debug("Feature distance: {}", featureDistance);
    return BendableBigDecimalScore.of(
            new BigDecimal[] { BigDecimal.ZERO, BigDecimal.valueOf(secondaryHardScore), BigDecimal.ZERO },
            new BigDecimal[] { BigDecimal.valueOf(featureDistance), BigDecimal.valueOf(secondarySoftScore) });
}
Use of org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity in project kogito-apps by kiegroup.
Class CounterfactualExplainer, method explainAsync.
/**
 * Asynchronously searches for a counterfactual of the given prediction.
 * <p>
 * The prediction's input features become counterfactual entities and its outputs are used as the goal.
 * Solving runs on the configured executor, and solutions published by the solver are passed to
 * {@code intermediateResultsConsumer}. When solving ends, the model is queried again with the final
 * entities. The returned {@link CounterfactualResult} combines the solution, those outputs and the
 * solution's feasibility.
 *
 * @param prediction must be a {@link CounterfactualPrediction} (otherwise the cast throws a
 *        {@link ClassCastException} synchronously); supplies inputs, goal outputs, execution id and an
 *        optional maximum running time in seconds
 * @param model the model queried during the search and for the final outputs
 * @param intermediateResultsConsumer receives intermediate results while solving
 * @return a future with the final result; it completes exceptionally, wrapping an
 *         {@link IllegalStateException}, if solving fails or the solving thread is interrupted
 */
@Override
public CompletableFuture<CounterfactualResult> explainAsync(Prediction prediction, PredictionProvider model, Consumer<CounterfactualResult> intermediateResultsConsumer) {
    final AtomicLong sequenceId = new AtomicLong(0);
    final CounterfactualPrediction cfPrediction = (CounterfactualPrediction) prediction;
    final UUID executionId = cfPrediction.getExecutionId();
    final Long maxRunningTimeSeconds = cfPrediction.getMaxRunningTimeSeconds();
    final List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(prediction.getInput());
    final List<Output> goal = prediction.getOutput().getOutputs();
    // Original features kept as structural reference to re-assemble composite features
    final List<Feature> originalFeatures = prediction.getInput().getFeatures();
    // The solver's problem id (uuid) is not used; each initial solution gets a fresh solution id
    Function<UUID, CounterfactualSolution> initial = uuid -> new CounterfactualSolution(entities, originalFeatures, model, goal, UUID.randomUUID(), executionId, this.counterfactualConfig.getGoalThreshold());
    final CompletableFuture<CounterfactualSolution> cfSolution = CompletableFuture.supplyAsync(() -> {
        SolverConfig solverConfig = this.counterfactualConfig.getSolverConfig();
        if (Objects.nonNull(maxRunningTimeSeconds)) {
            // withTerminationSpentLimit mutates the config it is called on. Apply the per-request limit
            // to a copy, so it cannot leak into the (possibly shared) config returned by
            // getSolverConfig() and affect later or concurrent explanations.
            solverConfig = new SolverConfig(solverConfig);
            solverConfig.withTerminationSpentLimit(Duration.ofSeconds(maxRunningTimeSeconds));
        }
        try (SolverManager<CounterfactualSolution, UUID> solverManager = this.counterfactualConfig.getSolverManagerFactory().apply(solverConfig)) {
            SolverJob<CounterfactualSolution, UUID> solverJob = solverManager.solveAndListen(executionId, initial, assignSolutionId.andThen(createSolutionConsumer(intermediateResultsConsumer, sequenceId)), null);
            try {
                // Wait until the solving ends
                return solverJob.getFinalBestSolution();
            } catch (ExecutionException e) {
                logger.error("Solving failed: {}", e.getMessage());
                throw new IllegalStateException("Prediction returned an error", e);
            } catch (InterruptedException e) {
                // Restore the interrupt flag before failing the future
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Solving failed (Thread interrupted)", e);
            }
        }
    }, this.counterfactualConfig.getExecutor());
    // Query the model again with the final entities to report the counterfactual's outputs
    final CompletableFuture<List<PredictionOutput>> cfOutputs = cfSolution.thenCompose(s -> model.predictAsync(buildInput(s.getEntities())));
    return cfSolution.thenCombine(cfOutputs, (solution, outputs) -> new CounterfactualResult(solution.getEntities(), solution.getOriginalFeatures(), outputs, solution.getScore().isFeasible(), UUID.randomUUID(), solution.getExecutionId(), sequenceId.incrementAndGet()));
}
Use of org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity in project kogito-apps by kiegroup.
Class CounterfactualExplainerServiceHandlerTest, method testCreateIntermediateResult.
/**
 * An intermediate counterfactual result is converted into a SUCCEEDED, INTERMEDIATE-stage
 * {@link CounterfactualExplainabilityResult} that carries the request ids and the single
 * numeric input and output as Double unit values.
 */
@Test
public void testCreateIntermediateResult() {
    final CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            Collections.emptyList(),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    // Fixture: one numeric entity (123.0, bounded 0..1000) and one numeric output (555.0)
    final Feature inputFeature = new Feature("input1", Type.NUMBER, new Value(123.0d));
    final List<CounterfactualEntity> cfEntities = List.of(DoubleEntity.from(inputFeature, 0, 1000));
    final List<Feature> cfFeatures = cfEntities.stream()
            .map(CounterfactualEntity::asFeature)
            .collect(Collectors.toList());
    final Output cfOutput = new Output("output1", Type.NUMBER, new Value(555.0d), 1.0);
    final List<PredictionOutput> cfPredictionOutputs = List.of(new PredictionOutput(List.of(cfOutput)));
    final CounterfactualResult counterfactuals = new CounterfactualResult(
            cfEntities,
            cfFeatures,
            cfPredictionOutputs,
            true,
            UUID.fromString(SOLUTION_ID),
            UUID.fromString(EXECUTION_ID),
            0);

    final BaseExplainabilityResult base = handler.createIntermediateResult(request, counterfactuals);
    assertTrue(base instanceof CounterfactualExplainabilityResult);
    final CounterfactualExplainabilityResult explainabilityResult = (CounterfactualExplainabilityResult) base;

    // Result metadata
    assertEquals(ExplainabilityStatus.SUCCEEDED, explainabilityResult.getStatus());
    assertEquals(CounterfactualExplainabilityResult.Stage.INTERMEDIATE, explainabilityResult.getStage());
    assertEquals(EXECUTION_ID, explainabilityResult.getExecutionId());
    assertEquals(COUNTERFACTUAL_ID, explainabilityResult.getCounterfactualId());

    // Inputs
    assertEquals(1, explainabilityResult.getInputs().size());
    assertTrue(explainabilityResult.getInputs().stream().anyMatch(in -> in.getName().equals("input1")));
    final NamedTypedValue firstInput = explainabilityResult.getInputs().iterator().next();
    assertEquals(Double.class.getSimpleName(), firstInput.getValue().getType());
    assertEquals(TypedValue.Kind.UNIT, firstInput.getValue().getKind());
    assertEquals(123.0, firstInput.getValue().toUnit().getValue().asDouble());

    // Outputs
    assertEquals(1, explainabilityResult.getOutputs().size());
    assertTrue(explainabilityResult.getOutputs().stream().anyMatch(out -> out.getName().equals("output1")));
    final NamedTypedValue firstOutput = explainabilityResult.getOutputs().iterator().next();
    assertEquals(Double.class.getSimpleName(), firstOutput.getValue().getType());
    assertEquals(TypedValue.Kind.UNIT, firstOutput.getValue().getKind());
    assertEquals(555.0, firstOutput.getValue().toUnit().getValue().asDouble());
}
Aggregations