Usage of org.kie.kogito.explainability.model.PredictionInput in project kogito-apps by kiegroup.
Class PmmlRegressionCategoricalLimeExplainerTest, method testExplanationStabilityWithOptimization.
/**
 * Optimizes a {@link LimeConfig} against a few sample predictions, then checks that a
 * {@link LimeExplainer} built from the optimized config yields stable local saliencies for a
 * single test input.
 *
 * <p>Disabled upstream, see KOGITO-6154.
 */
@Disabled("See KOGITO-6154")
@Test
void testExplanationStabilityWithOptimization() throws ExecutionException, InterruptedException, TimeoutException {
    PredictionProvider model = getModel();

    // Only the first five samples drive the optimizer. Use the same sub-list for both inputs and
    // outputs so DataUtils.getPredictions receives lists of equal length and pairs each input
    // with its own output.
    List<PredictionInput> samples = getSamples().subList(0, 5);
    // Bounded wait, consistent with the test-input prediction below, so a stuck model cannot
    // hang the test indefinitely.
    List<PredictionOutput> predictionOutputs = model.predictAsync(samples)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    List<Prediction> predictions = DataUtils.getPredictions(samples, predictionOutputs);

    long seed = 0;
    LimeConfigOptimizer limeConfigOptimizer = new LimeConfigOptimizer().withDeterministicExecution(true);
    Random random = new Random();
    PerturbationContext perturbationContext = new PerturbationContext(seed, random, 1);
    LimeConfig initialConfig = new LimeConfig().withSamples(10).withPerturbationContext(perturbationContext);
    LimeConfig optimizedConfig = limeConfigOptimizer.optimize(initialConfig, predictions, model);
    // The optimizer is expected to hand back a new config instance rather than the input one.
    assertThat(optimizedConfig).isNotSameAs(initialConfig);

    LimeExplainer limeExplainer = new LimeExplainer(optimizedConfig);
    PredictionInput testPredictionInput = getTestInput();
    List<PredictionOutput> testPredictionOutputs = model.predictAsync(List.of(testPredictionInput))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    Prediction instance = new SimplePrediction(testPredictionInput, testPredictionOutputs.get(0));
    assertDoesNotThrow(() -> ValidationUtils.validateLocalSaliencyStability(model, instance, limeExplainer, 1, 0.6, 0.6));
}
Usage of org.kie.kogito.explainability.model.PredictionInput in project kogito-apps by kiegroup.
Class CounterfactualScoreCalculatorTest, method testGoalSizeMatch.
/**
 * If the goal and the model's output are the same, the distances should all be zero.
 *
 * <p>Builds three input features (two numerical, one boolean) and a two-element goal on
 * {@code f-2}/{@code f-3}. It then checks that the resulting score is feasible, that the score
 * has 3 hard and 2 soft levels, and that every level is zero. The model comes from
 * {@code TestUtils.getFeatureSkipModel(0)}; presumably it drops feature index 0 and echoes the
 * rest, which the two-output assertion below relies on. TODO confirm against TestUtils.
 */
@Test
void testGoalSizeMatch() throws ExecutionException, InterruptedException {
    final CounterFactualScoreCalculator scoreCalculator = new CounterFactualScoreCalculator();
    PredictionProvider model = TestUtils.getFeatureSkipModel(0);

    List<Feature> features = new ArrayList<>();
    features.add(FeatureFactory.newNumericalFeature("f-1", 1.0));
    features.add(FeatureFactory.newNumericalFeature("f-2", 2.0));
    features.add(FeatureFactory.newBooleanFeature("f-3", true));
    PredictionInput input = new PredictionInput(features);
    List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(input);

    // Goal values equal the input values of f-2 and f-3, i.e. what the model should output.
    List<Output> goal = new ArrayList<>();
    goal.add(new Output("f-2", Type.NUMBER, new Value(2.0), 0.0));
    goal.add(new Output("f-3", Type.BOOLEAN, new Value(true), 0.0));

    final CounterfactualSolution solution = new CounterfactualSolution(entities, features, model, goal, UUID.randomUUID(), UUID.randomUUID(), 0.0);
    BendableBigDecimalScore score = scoreCalculator.calculateScore(solution);
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input)).get();

    assertTrue(score.isFeasible());
    assertEquals(2, goal.size());
    // A single prediction is expected
    assertEquals(1, predictionOutputs.size());
    // Single prediction with two features
    assertEquals(2, predictionOutputs.get(0).getOutputs().size());

    // Check level counts first so a shape mismatch fails clearly instead of at an index lookup.
    assertEquals(3, score.getHardLevelsSize());
    assertEquals(2, score.getSoftLevelsSize());
    for (int level = 0; level < score.getHardLevelsSize(); level++) {
        assertEquals(0, score.getHardScore(level).compareTo(BigDecimal.ZERO));
    }
    for (int level = 0; level < score.getSoftLevelsSize(); level++) {
        assertEquals(0, score.getSoftScore(level).compareTo(BigDecimal.ZERO));
    }
}
Usage of org.kie.kogito.explainability.model.PredictionInput in project kogito-apps by kiegroup.
Class DatasetEncoderTest, method testDatasetEncodingWithBooleanData.
/**
 * Encodes ten perturbed inputs, each made of three mocked boolean features alternating
 * true/false/true, against an original input with the same alternating pattern.
 */
@Test
void testDatasetEncodingWithBooleanData() {
    List<PredictionInput> perturbedInputs = new LinkedList<>();
    int sample = 0;
    while (sample < 10) {
        List<Feature> sampleFeatures = new LinkedList<>();
        for (int idx = 0; idx < 3; idx++) {
            boolean even = idx % 2 == 0;
            sampleFeatures.add(TestUtils.getMockedFeature(Type.BOOLEAN, new Value(even)));
        }
        perturbedInputs.add(new PredictionInput(sampleFeatures));
        sample++;
    }

    List<Feature> originalFeatures = new LinkedList<>();
    for (int idx = 0; idx < 3; idx++) {
        boolean even = idx % 2 == 0;
        originalFeatures.add(TestUtils.getMockedFeature(Type.BOOLEAN, new Value(even)));
    }
    assertEncode(perturbedInputs, new PredictionInput(originalFeatures));
}
Usage of org.kie.kogito.explainability.model.PredictionInput in project kogito-apps by kiegroup.
Class DatasetEncoderTest, method testEmptyDatasetEncoding.
@Test
void testEmptyDatasetEncoding() {
List<PredictionInput> inputs = new LinkedList<>();
List<Output> outputs = new LinkedList<>();
List<Feature> features = new LinkedList<>();
Output originalOutput = new Output("foo", Type.NUMBER, new Value(1), 1d);
EncodingParams params = new EncodingParams(1, 0.1);
DatasetEncoder datasetEncoder = new DatasetEncoder(inputs, outputs, features, originalOutput, params);
Collection<Pair<double[], Double>> trainingSet = datasetEncoder.getEncodedTrainingSet();
assertNotNull(trainingSet);
assertTrue(trainingSet.isEmpty());
}
Usage of org.kie.kogito.explainability.model.PredictionInput in project kogito-apps by kiegroup.
Class DatasetEncoderTest, method testDatasetEncodingWithBinaryData.
@Test
void testDatasetEncodingWithBinaryData() {
List<PredictionInput> perturbedInputs = new LinkedList<>();
for (int i = 0; i < 10; i++) {
List<Feature> inputFeatures = new LinkedList<>();
for (int j = 0; j < 3; j++) {
ByteBuffer byteBuffer = ByteBuffer.wrap((i + "" + j).getBytes(Charset.defaultCharset()));
inputFeatures.add(TestUtils.getMockedFeature(Type.BINARY, new Value(byteBuffer)));
}
perturbedInputs.add(new PredictionInput(inputFeatures));
}
List<Feature> features = new LinkedList<>();
for (int i = 0; i < 3; i++) {
ByteBuffer byteBuffer = ByteBuffer.wrap((i + "" + i).getBytes(Charset.defaultCharset()));
features.add(TestUtils.getMockedFeature(Type.BINARY, new Value(byteBuffer)));
}
PredictionInput originalInput = new PredictionInput(features);
assertEncode(perturbedInputs, originalInput);
}
Aggregations