Search in sources :

Example 91 with PredictionOutput

use of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.

the class TestUtils method getDummyTextClassifier.

/**
 * Builds a toy text classifier: an input is flagged as spam when any space-separated
 * word of any of its features' string values is exactly one of the blacklisted tokens.
 *
 * @return a {@link PredictionProvider} that yields one boolean {@code "spam"} output per input
 */
public static PredictionProvider getDummyTextClassifier() {
    List<String> blackList = Arrays.asList("money", "$", "£", "bitcoin");
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> results = new LinkedList<>();
        for (PredictionInput input : inputs) {
            // anyMatch short-circuits, so features after the first hit are never inspected
            boolean spam = input.getFeatures().stream()
                    .anyMatch(feature -> Arrays.stream(feature.getValue().asString().split(" "))
                            .anyMatch(blackList::contains));
            Output verdict = new Output("spam", Type.BOOLEAN, new Value(spam), 1d);
            results.add(new PredictionOutput(List.of(verdict)));
        }
        return results;
    });
}
Also used : Arrays(java.util.Arrays) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) Random(java.util.Random) Mockito.when(org.mockito.Mockito.when) Value(org.kie.kogito.explainability.model.Value) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ArrayList(java.util.ArrayList) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Pair(org.apache.commons.lang3.tuple.Pair) Output(org.kie.kogito.explainability.model.Output) CompletableFuture.supplyAsync(java.util.concurrent.CompletableFuture.supplyAsync) Optional(java.util.Optional) ValidationUtils(org.kie.kogito.explainability.utils.ValidationUtils) LinkedList(java.util.LinkedList) Assertions.assertDoesNotThrow(org.junit.jupiter.api.Assertions.assertDoesNotThrow) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Mockito.mock(org.mockito.Mockito.mock) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Value(org.kie.kogito.explainability.model.Value) Feature(org.kie.kogito.explainability.model.Feature) LinkedList(java.util.LinkedList)

Example 92 with PredictionOutput

use of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.

the class TestUtils method getFeaturePassModel.

/**
 * Test model whose single output mirrors one chosen input feature (same type and value).
 *
 * @param featureIndex position, within each input's feature list, of the feature to pass through
 * @return a {@link PredictionProvider} producing one output named {@code "feature-" + featureIndex} per input
 */
public static PredictionProvider getFeaturePassModel(int featureIndex) {
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> results = new LinkedList<>();
        // the output name is identical for every input, so build it once
        String outputName = "feature-" + featureIndex;
        for (PredictionInput input : inputs) {
            Feature selected = input.getFeatures().get(featureIndex);
            Output passedThrough = new Output(outputName, selected.getType(), selected.getValue(), 1d);
            results.add(new PredictionOutput(List.of(passedThrough)));
        }
        return results;
    });
}
Also used : Arrays(java.util.Arrays) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) Random(java.util.Random) Mockito.when(org.mockito.Mockito.when) Value(org.kie.kogito.explainability.model.Value) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ArrayList(java.util.ArrayList) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Pair(org.apache.commons.lang3.tuple.Pair) Output(org.kie.kogito.explainability.model.Output) CompletableFuture.supplyAsync(java.util.concurrent.CompletableFuture.supplyAsync) Optional(java.util.Optional) ValidationUtils(org.kie.kogito.explainability.utils.ValidationUtils) LinkedList(java.util.LinkedList) Assertions.assertDoesNotThrow(org.junit.jupiter.api.Assertions.assertDoesNotThrow) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Mockito.mock(org.mockito.Mockito.mock) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Feature(org.kie.kogito.explainability.model.Feature) LinkedList(java.util.LinkedList)

Example 93 with PredictionOutput

use of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.

the class TestUtils method getFeatureSkipModel.

/**
 * Test model that echoes every input feature back as an output (same name, type and value),
 * leaving out the feature at one specified position.
 *
 * @param featureIndex Index of the input feature to omit from output
 * @return A {@link PredictionProvider} model
 */
public static PredictionProvider getFeatureSkipModel(int featureIndex) {
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> results = new LinkedList<>();
        for (PredictionInput input : inputs) {
            List<Feature> features = input.getFeatures();
            List<Output> kept = new ArrayList<>();
            for (int position = 0; position < features.size(); position++) {
                if (position == featureIndex) {
                    // this is the feature the model is asked to drop
                    continue;
                }
                Feature current = features.get(position);
                kept.add(new Output(current.getName(), current.getType(), current.getValue(), 1.0));
            }
            results.add(new PredictionOutput(kept));
        }
        return results;
    });
}
Also used : Arrays(java.util.Arrays) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) Random(java.util.Random) Mockito.when(org.mockito.Mockito.when) Value(org.kie.kogito.explainability.model.Value) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ArrayList(java.util.ArrayList) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Pair(org.apache.commons.lang3.tuple.Pair) Output(org.kie.kogito.explainability.model.Output) CompletableFuture.supplyAsync(java.util.concurrent.CompletableFuture.supplyAsync) Optional(java.util.Optional) ValidationUtils(org.kie.kogito.explainability.utils.ValidationUtils) LinkedList(java.util.LinkedList) Assertions.assertDoesNotThrow(org.junit.jupiter.api.Assertions.assertDoesNotThrow) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Mockito.mock(org.mockito.Mockito.mock) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) ArrayList(java.util.ArrayList) Feature(org.kie.kogito.explainability.model.Feature) LinkedList(java.util.LinkedList)

Example 94 with PredictionOutput

use of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.

the class TestUtils method getEvenFeatureModel.

/**
 * Test model that reports whether the numeric value of one chosen input feature is even.
 *
 * @param featureIndex position, within each input's feature list, of the feature to test
 * @return a {@link PredictionProvider} producing one boolean output named
 *         {@code "feature-" + featureIndex} per input
 */
public static PredictionProvider getEvenFeatureModel(int featureIndex) {
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> results = new LinkedList<>();
        for (PredictionInput input : inputs) {
            double number = input.getFeatures().get(featureIndex).getValue().asNumber();
            boolean even = number % 2 == 0;
            Output parity = new Output("feature-" + featureIndex, Type.BOOLEAN, new Value(even), 1d);
            results.add(new PredictionOutput(List.of(parity)));
        }
        return results;
    });
}
Also used : Arrays(java.util.Arrays) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) Random(java.util.Random) Mockito.when(org.mockito.Mockito.when) Value(org.kie.kogito.explainability.model.Value) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ArrayList(java.util.ArrayList) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Pair(org.apache.commons.lang3.tuple.Pair) Output(org.kie.kogito.explainability.model.Output) CompletableFuture.supplyAsync(java.util.concurrent.CompletableFuture.supplyAsync) Optional(java.util.Optional) ValidationUtils(org.kie.kogito.explainability.utils.ValidationUtils) LinkedList(java.util.LinkedList) Assertions.assertDoesNotThrow(org.junit.jupiter.api.Assertions.assertDoesNotThrow) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Mockito.mock(org.mockito.Mockito.mock) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Value(org.kie.kogito.explainability.model.Value) Feature(org.kie.kogito.explainability.model.Feature) LinkedList(java.util.LinkedList)

Example 95 with PredictionOutput

use of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.

the class AggregatedLimeExplainerTest method testExplainWithPredictions.

/**
 * Checks that explaining a sum-skip model from precomputed predictions yields a single
 * saliency, keyed "sum-but1", whose top two positive features do not include "f1".
 *
 * @param seed seed for the random generator that drives data generation
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testExplainWithPredictions(int seed) throws ExecutionException, InterruptedException {
    Random random = new Random(seed);
    PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(1);

    // build predictions by running the model on samples drawn from a random distribution
    DataDistribution dataDistribution = DataUtils.generateRandomDataDistribution(3, 100, random);
    List<PredictionInput> samples = dataDistribution.sample(10);
    List<PredictionOutput> modelOutputs = sumSkipModel.predictAsync(samples).get();
    List<Prediction> predictions = DataUtils.getPredictions(samples, modelOutputs);

    AggregatedLimeExplainer explainer = new AggregatedLimeExplainer();
    Map<String, Saliency> saliencies = explainer.explainFromPredictions(sumSkipModel, predictions).get();
    assertNotNull(saliencies);
    assertEquals(1, saliencies.size());

    String outputName = "sum-but1";
    assertTrue(saliencies.containsKey(outputName));
    Saliency saliency = saliencies.get(outputName);
    assertNotNull(saliency);

    List<String> topPositiveNames = saliency.getPositiveFeatures(2)
            .stream()
            .map(FeatureImportance::getFeature)
            .map(Feature::getName)
            .collect(Collectors.toList());
    // skipped feature should not appear in top two positive features
    assertFalse(topPositiveNames.contains("f1"));
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) Saliency(org.kie.kogito.explainability.model.Saliency) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Random(java.util.Random) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) ValueSource(org.junit.jupiter.params.provider.ValueSource) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Aggregations

PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput)155 PredictionInput (org.kie.kogito.explainability.model.PredictionInput)137 PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider)124 Prediction (org.kie.kogito.explainability.model.Prediction)122 Random (java.util.Random)90 Test (org.junit.jupiter.api.Test)90 SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction)89 Feature (org.kie.kogito.explainability.model.Feature)80 ArrayList (java.util.ArrayList)74 Output (org.kie.kogito.explainability.model.Output)65 PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext)65 ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)55 LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig)52 LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer)50 Saliency (org.kie.kogito.explainability.model.Saliency)48 Value (org.kie.kogito.explainability.model.Value)47 LinkedList (java.util.LinkedList)37 List (java.util.List)36 LimeConfigOptimizer (org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer)33 ValueSource (org.junit.jupiter.params.provider.ValueSource)32