Search in sources:

Example 66 with PredictionInput

Use of org.kie.kogito.explainability.model.PredictionInput in project kogito-apps by kiegroup.

Class FairnessMetricsTest, method getTestInputs.

/**
 * Builds the fixture of nine e-mail-like inputs used by the text-classifier fairness tests.
 * Each input has two full-text features, "subject" and "text", tokenized on single spaces.
 *
 * @return a new mutable list of prediction inputs, in a fixed order
 */
private List<PredictionInput> getTestInputs() {
    // String.split already returns a fresh array, so no defensive clone is needed before wrapping it.
    Function<String, List<String>> tokenizer = s -> Arrays.asList(s.split(" "));
    List<PredictionInput> inputs = new ArrayList<>();
    inputs.add(newSubjectTextInput("urgent inquiry", "please give me some money", tokenizer));
    inputs.add(newSubjectTextInput("please reply", "we got urgent matter! please reply", tokenizer));
    inputs.add(newSubjectTextInput("please reply", "we got money matter! please reply", tokenizer));
    inputs.add(newSubjectTextInput("inquiry", "would you like to get a 100% secure way to invest your money?", tokenizer));
    inputs.add(newSubjectTextInput("you win", "you just won an incredible 1M $ prize !", tokenizer));
    inputs.add(newSubjectTextInput("prize waiting", "you are the lucky winner of a 100k $ prize", tokenizer));
    inputs.add(newSubjectTextInput("urgent matter", "we got an urgent inquiry for you to answer.", tokenizer));
    inputs.add(newSubjectTextInput("password change", "you just requested to change your password", tokenizer));
    inputs.add(newSubjectTextInput("password stolen", "we stole your password, if you want it back, send some money .", tokenizer));
    return inputs;
}

/**
 * Creates one two-feature input: a "subject" full-text feature followed by a "text" full-text feature.
 *
 * @param subject value of the "subject" feature
 * @param text value of the "text" feature
 * @param tokenizer tokenizer applied to both features
 * @return the assembled prediction input
 */
private static PredictionInput newSubjectTextInput(String subject, String text, Function<String, List<String>> tokenizer) {
    List<Feature> features = new ArrayList<>();
    features.add(FeatureFactory.newFulltextFeature("subject", subject, tokenizer));
    features.add(FeatureFactory.newFulltextFeature("text", text, tokenizer));
    return new PredictionInput(features);
}
Also used : FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) Arrays(java.util.Arrays) Predicate(java.util.function.Predicate) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) BiFunction(java.util.function.BiFunction) Dataset(org.kie.kogito.explainability.model.Dataset) Value(org.kie.kogito.explainability.model.Value) Function(java.util.function.Function) Collectors(java.util.stream.Collectors) StringUtils(org.apache.commons.lang3.StringUtils) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ArrayList(java.util.ArrayList) ExecutionException(java.util.concurrent.ExecutionException) Test(org.junit.jupiter.api.Test) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) TestUtils(org.kie.kogito.explainability.TestUtils) Locale(java.util.Locale) Output(org.kie.kogito.explainability.model.Output) AssertionsForClassTypes.assertThat(org.assertj.core.api.AssertionsForClassTypes.assertThat) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) ArrayList(java.util.ArrayList) List(java.util.List) Feature(org.kie.kogito.explainability.model.Feature)

Example 67 with PredictionInput

Use of org.kie.kogito.explainability.model.PredictionInput in project kogito-apps by kiegroup.

Class FairnessMetricsTest, method testIndividualConsistencyTextClassifier.

/**
 * Checks that individual consistency computed on the dummy text classifier is a value in [0, 1].
 */
@Test
void testIndividualConsistencyTextClassifier() throws ExecutionException, InterruptedException {
    // Neighbourhood of an input: all inputs ranked by fuzzy similarity to it (higher score = more similar),
    // skipping the top-ranked one (presumably the input itself) and keeping the next two.
    BiFunction<PredictionInput, List<PredictionInput>, List<PredictionInput>> proximityFunction = (predictionInput, predictionInputs) -> {
        String reference = DataUtils.textify(predictionInput);
        Locale locale = Locale.getDefault();
        return predictionInputs.stream()
                // Descending order by fuzzy score; Integer.compare instead of int subtraction in a comparator.
                .sorted((o1, o2) -> Integer.compare(
                        StringUtils.getFuzzyDistance(DataUtils.textify(o2), reference, locale),
                        StringUtils.getFuzzyDistance(DataUtils.textify(o1), reference, locale)))
                .collect(Collectors.toList())
                .subList(1, 3);
    };
    List<PredictionInput> testInputs = getTestInputs();
    PredictionProvider model = TestUtils.getDummyTextClassifier();
    double individualConsistency = FairnessMetrics.individualConsistency(proximityFunction, testInputs, model);
    assertThat(individualConsistency).isBetween(0d, 1d);
}
Also used : FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) Arrays(java.util.Arrays) Predicate(java.util.function.Predicate) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) BiFunction(java.util.function.BiFunction) Dataset(org.kie.kogito.explainability.model.Dataset) Value(org.kie.kogito.explainability.model.Value) Function(java.util.function.Function) Collectors(java.util.stream.Collectors) StringUtils(org.apache.commons.lang3.StringUtils) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ArrayList(java.util.ArrayList) ExecutionException(java.util.concurrent.ExecutionException) Test(org.junit.jupiter.api.Test) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) TestUtils(org.kie.kogito.explainability.TestUtils) Locale(java.util.Locale) Output(org.kie.kogito.explainability.model.Output) AssertionsForClassTypes.assertThat(org.assertj.core.api.AssertionsForClassTypes.assertThat) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) List(java.util.List) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Test(org.junit.jupiter.api.Test)

Example 68 with PredictionInput

Use of org.kie.kogito.explainability.model.PredictionInput in project kogito-apps by kiegroup.

Class MatrixUtilsExtensionsTest, method testPICreation.

// === Matrix creation tests =======================================================================================
// test creation of matrix from single PredictionInput
@Test
void testPICreation() {
    // Wrap the first 5 values of row 0 of the 3x5 fixture as numerical features named "f",
    // then check that converting back to a RealVector reproduces that row.
    double[] firstRow = mat3X5[0];
    List<Feature> features = new ArrayList<>();
    int col = 0;
    while (col < 5) {
        features.add(FeatureFactory.newNumericalFeature("f", firstRow[col]));
        col++;
    }
    RealVector vector = MatrixUtilsExtensions.vectorFromPredictionInput(new PredictionInput(features));
    assertArrayEquals(firstRow, vector.toArray());
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) RealVector(org.apache.commons.math3.linear.RealVector) ArrayList(java.util.ArrayList) Feature(org.kie.kogito.explainability.model.Feature) Test(org.junit.jupiter.api.Test)

Example 69 with PredictionInput

Use of org.kie.kogito.explainability.model.PredictionInput in project kogito-apps by kiegroup.

Class MatrixUtilsExtensionsTest, method testPIListCreation.

// test creation of matrix from list of PredictionInputs
@Test
void testPIListCreation() {
    // Each of the 3 rows of the 3x5 fixture becomes its own PredictionInput of 5 numerical "f" features;
    // the list is then converted to a RealMatrix and compared element-wise with the fixture.
    List<PredictionInput> inputs = new ArrayList<>();
    int row = 0;
    while (row < 3) {
        double[] values = mat3X5[row];
        List<Feature> rowFeatures = new ArrayList<>();
        for (int col = 0; col < 5; col++) {
            rowFeatures.add(FeatureFactory.newNumericalFeature("f", values[col]));
        }
        inputs.add(new PredictionInput(rowFeatures));
        row++;
    }
    RealMatrix matrix = MatrixUtilsExtensions.matrixFromPredictionInput(inputs);
    assertArrayEquals(mat3X5, matrix.getData());
}
Also used : RealMatrix(org.apache.commons.math3.linear.RealMatrix) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) Feature(org.kie.kogito.explainability.model.Feature) Test(org.junit.jupiter.api.Test)

Example 70 with PredictionInput

Use of org.kie.kogito.explainability.model.PredictionInput in project kogito-apps by kiegroup.

Class ComplexEligibilityDmnCounterfactualExplainerTest, method testDMNScoringFunction.

/**
 * Runs the counterfactual explainer against the DMN model with a goal built by
 * {@code generateGoal(true, true, 1.0)}. It then checks the returned outputs
 * (inputsAreValid, canRequestLoan, my-scoring-function ~ 1.0) and the counterfactual
 * feature values (age, hasReferral, monthlySalary).
 */
@Test
void testDMNScoringFunction() throws ExecutionException, InterruptedException, TimeoutException {
    PredictionProvider model = getModel();
    final List<Output> goal = generateGoal(true, true, 1.0);
    List<Feature> features = new LinkedList<>();
    features.add(FeatureFactory.newNumericalFeature("age", 40, NumericalFeatureDomain.create(18, 60)));
    features.add(FeatureFactory.newBooleanFeature("hasReferral", true));
    features.add(FeatureFactory.newNumericalFeature("monthlySalary", 500, NumericalFeatureDomain.create(10, 100_000)));
    // for the purpose of this test, only a few steps are necessary
    final TerminationConfig terminationConfig = new TerminationConfig().withScoreCalculationCountLimit(10_000L);
    final SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(terminationConfig).build();
    // Fixed seed + REPRODUCIBLE mode pin the solver's randomness, since the assertions below check concrete values.
    solverConfig.setRandomSeed(23L);
    solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
    final CounterfactualConfig counterfactualConfig = new CounterfactualConfig().withSolverConfig(solverConfig).withGoalThreshold(0.01);
    final CounterfactualExplainer counterfactualExplainer = new CounterfactualExplainer(counterfactualConfig);
    PredictionInput input = new PredictionInput(features);
    PredictionOutput output = new PredictionOutput(goal);
    // NOTE(review): the null argument and 60L are passed through as-is; confirm their meaning in CounterfactualPrediction.
    Prediction prediction = new CounterfactualPrediction(input, output, null, UUID.randomUUID(), 60L);
    final CounterfactualResult counterfactualResult = counterfactualExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    List<Output> cfOutputs = counterfactualResult.getOutput().get(0).getOutputs();
    assertTrue(counterfactualResult.isValid());
    assertEquals("inputsAreValid", cfOutputs.get(0).getName());
    assertTrue((Boolean) cfOutputs.get(0).getValue().getUnderlyingObject());
    assertEquals("canRequestLoan", cfOutputs.get(1).getName());
    assertTrue((Boolean) cfOutputs.get(1).getValue().getUnderlyingObject());
    assertEquals("my-scoring-function", cfOutputs.get(2).getName());
    assertEquals(1.0, ((BigDecimal) cfOutputs.get(2).getValue().getUnderlyingObject()).doubleValue(), 0.01);
    List<CounterfactualEntity> entities = counterfactualResult.getEntities();
    assertEquals("age", entities.get(0).asFeature().getName());
    assertEquals(18, entities.get(0).asFeature().getValue().asNumber());
    assertEquals("hasReferral", entities.get(1).asFeature().getName());
    assertTrue((Boolean) entities.get(1).asFeature().getValue().getUnderlyingObject());
    assertEquals("monthlySalary", entities.get(2).asFeature().getName());
    final double monthlySalary = entities.get(2).asFeature().getValue().asNumber();
    assertEquals(7900, monthlySalary, 10);
    // since the scoring function is ((0.6 * ((42 - age + 18)/42)) + (0.4 * (monthlySalary/8000)))
    // for a result of 1.0 the relation must be age = (7*monthlySalary)/2000 - 10
    assertEquals(18, (7 * monthlySalary) / 2000.0 - 10.0, 0.5);
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) LinkedList(java.util.LinkedList) CounterfactualResult(org.kie.kogito.explainability.local.counterfactual.CounterfactualResult) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) TerminationConfig(org.optaplanner.core.config.solver.termination.TerminationConfig) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) CounterfactualConfig(org.kie.kogito.explainability.local.counterfactual.CounterfactualConfig) CounterfactualExplainer(org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer) SolverConfig(org.optaplanner.core.config.solver.SolverConfig) Test(org.junit.jupiter.api.Test)

Aggregations

PredictionInput (org.kie.kogito.explainability.model.PredictionInput)187 PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput)143 PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider)135 Prediction (org.kie.kogito.explainability.model.Prediction)126 Feature (org.kie.kogito.explainability.model.Feature)109 Test (org.junit.jupiter.api.Test)107 ArrayList (java.util.ArrayList)97 SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction)95 Random (java.util.Random)86 PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext)67 ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)60 Output (org.kie.kogito.explainability.model.Output)55 LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig)54 LinkedList (java.util.LinkedList)53 LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer)52 Value (org.kie.kogito.explainability.model.Value)52 Saliency (org.kie.kogito.explainability.model.Saliency)50 List (java.util.List)39 LimeConfigOptimizer (org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer)33 Type (org.kie.kogito.explainability.model.Type)31