
Example 76 with Value

Use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.

From class FairnessMetricsTest, method getTestData.

private List<Prediction> getTestData() {
    List<Prediction> data = new ArrayList<>();
    // split on single spaces; split() already returns a fresh array, so no clone() is needed
    Function<String, List<String>> tokenizer = s -> Arrays.asList(s.split(" "));
    List<Feature> features = new ArrayList<>();
    features.add(FeatureFactory.newFulltextFeature("subject", "urgent inquiry", tokenizer));
    features.add(FeatureFactory.newFulltextFeature("text", "please give me some money", tokenizer));
    Output output = new Output("spam", Type.BOOLEAN, new Value(true), 1);
    data.add(new SimplePrediction(new PredictionInput(features), new PredictionOutput(List.of(output))));
    features = new ArrayList<>();
    features.add(FeatureFactory.newFulltextFeature("subject", "do not reply", tokenizer));
    features.add(FeatureFactory.newFulltextFeature("text", "if you asked to reset your password, ignore this", tokenizer));
    output = new Output("spam", Type.BOOLEAN, new Value(false), 1);
    data.add(new SimplePrediction(new PredictionInput(features), new PredictionOutput(List.of(output))));
    features = new ArrayList<>();
    features.add(FeatureFactory.newFulltextFeature("subject", "please reply", tokenizer));
    features.add(FeatureFactory.newFulltextFeature("text", "we got money matter! please reply", tokenizer));
    output = new Output("spam", Type.BOOLEAN, new Value(true), 1);
    data.add(new SimplePrediction(new PredictionInput(features), new PredictionOutput(List.of(output))));
    features = new ArrayList<>();
    features.add(FeatureFactory.newFulltextFeature("subject", "inquiry", tokenizer));
    features.add(FeatureFactory.newFulltextFeature("text", "would you like to get a 100% secure way to invest your money?", tokenizer));
    output = new Output("spam", Type.BOOLEAN, new Value(true), 1);
    data.add(new SimplePrediction(new PredictionInput(features), new PredictionOutput(List.of(output))));
    features = new ArrayList<>();
    features.add(FeatureFactory.newFulltextFeature("subject", "clear some space", tokenizer));
    features.add(FeatureFactory.newFulltextFeature("text", "you just finished your space, upgrade today for 1 $ a week", tokenizer));
    output = new Output("spam", Type.BOOLEAN, new Value(false), 1);
    data.add(new SimplePrediction(new PredictionInput(features), new PredictionOutput(List.of(output))));
    features = new ArrayList<>();
    features.add(FeatureFactory.newFulltextFeature("subject", "prize waiting", tokenizer));
    features.add(FeatureFactory.newFulltextFeature("text", "you are the lucky winner of a 100k $ prize", tokenizer));
    output = new Output("spam", Type.BOOLEAN, new Value(true), 1);
    data.add(new SimplePrediction(new PredictionInput(features), new PredictionOutput(List.of(output))));
    features = new ArrayList<>();
    features.add(FeatureFactory.newFulltextFeature("subject", "urgent matter", tokenizer));
    features.add(FeatureFactory.newFulltextFeature("text", "we got an urgent inquiry for you to answer.", tokenizer));
    output = new Output("spam", Type.BOOLEAN, new Value(true), 1);
    data.add(new SimplePrediction(new PredictionInput(features), new PredictionOutput(List.of(output))));
    features = new ArrayList<>();
    features.add(FeatureFactory.newFulltextFeature("subject", "password change", tokenizer));
    features.add(FeatureFactory.newFulltextFeature("text", "you just requested to change your password", tokenizer));
    output = new Output("spam", Type.BOOLEAN, new Value(false), 1);
    data.add(new SimplePrediction(new PredictionInput(features), new PredictionOutput(List.of(output))));
    features = new ArrayList<>();
    features.add(FeatureFactory.newFulltextFeature("subject", "password stolen", tokenizer));
    features.add(FeatureFactory.newFulltextFeature("text", "we stole your password, if you want it back, send some money .", tokenizer));
    output = new Output("spam", Type.BOOLEAN, new Value(true), 1);
    data.add(new SimplePrediction(new PredictionInput(features), new PredictionOutput(List.of(output))));
    return data;
}
Also used: FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory), SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction), Arrays(java.util.Arrays), Predicate(java.util.function.Predicate), Feature(org.kie.kogito.explainability.model.Feature), Prediction(org.kie.kogito.explainability.model.Prediction), BiFunction(java.util.function.BiFunction), Dataset(org.kie.kogito.explainability.model.Dataset), Value(org.kie.kogito.explainability.model.Value), Function(java.util.function.Function), Collectors(java.util.stream.Collectors), StringUtils(org.apache.commons.lang3.StringUtils), Type(org.kie.kogito.explainability.model.Type), PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider), ArrayList(java.util.ArrayList), ExecutionException(java.util.concurrent.ExecutionException), Test(org.junit.jupiter.api.Test), PredictionInput(org.kie.kogito.explainability.model.PredictionInput), List(java.util.List), TestUtils(org.kie.kogito.explainability.TestUtils), Locale(java.util.Locale), Output(org.kie.kogito.explainability.model.Output), AssertionsForClassTypes.assertThat(org.assertj.core.api.AssertionsForClassTypes.assertThat), PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput)
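getTestData() repeats the same four statements for each of its nine e-mails. A minimal, library-free sketch of a table-driven alternative follows; the `SpamDatasetSketch` class, the `LabelledEmail` record and the `spamCount` helper are hypothetical stand-ins for the Kogito `Feature`/`Output`/`Prediction` classes, not part of kogito-apps. It keeps the test's space-based tokenizer, which leaves punctuation attached to words (for example `matter!`).

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

// Hypothetical, library-free sketch: a plain record replaces the Kogito model classes
// so that the layout of the test data can be shown in isolation.
final class SpamDatasetSketch {

    // One labelled e-mail: two full-text fields plus the boolean "spam" label.
    record LabelledEmail(String subject, String text, boolean spam) {}

    // Same tokenizer as the test: splits on single spaces and keeps punctuation attached.
    static final Function<String, List<String>> TOKENIZER = s -> Arrays.asList(s.split(" "));

    // The nine samples from getTestData(), written as rows instead of repeated statements.
    static List<LabelledEmail> rows() {
        return List.of(
            new LabelledEmail("urgent inquiry", "please give me some money", true),
            new LabelledEmail("do not reply", "if you asked to reset your password, ignore this", false),
            new LabelledEmail("please reply", "we got money matter! please reply", true),
            new LabelledEmail("inquiry", "would you like to get a 100% secure way to invest your money?", true),
            new LabelledEmail("clear some space", "you just finished your space, upgrade today for 1 $ a week", false),
            new LabelledEmail("prize waiting", "you are the lucky winner of a 100k $ prize", true),
            new LabelledEmail("urgent matter", "we got an urgent inquiry for you to answer.", true),
            new LabelledEmail("password change", "you just requested to change your password", false),
            new LabelledEmail("password stolen", "we stole your password, if you want it back, send some money .", true));
    }

    // Counts the rows whose label is spam = true.
    static long spamCount(List<LabelledEmail> rows) {
        return rows.stream().filter(LabelledEmail::spam).count();
    }

    public static void main(String[] args) {
        List<LabelledEmail> rows = rows();
        System.out.println(rows.size() + " samples, " + spamCount(rows) + " labelled spam");
        System.out.println(TOKENIZER.apply(rows.get(2).text()));
    }
}
```

In the real test, a small helper could turn each row into the two `FeatureFactory.newFulltextFeature` calls and the boolean `Output` shown above; such a helper would likewise be hypothetical.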

Example 77 with Value

Use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.

From class FairnessMetricsTest, method testGroupSPDTextClassifier.

@Test
void testGroupSPDTextClassifier() throws ExecutionException, InterruptedException {
    List<PredictionInput> testInputs = getTestInputs();
    PredictionProvider model = TestUtils.getDummyTextClassifier();
    Predicate<PredictionInput> selector = predictionInput -> DataUtils.textify(predictionInput).contains("please");
    Output output = new Output("spam", Type.BOOLEAN, new Value(false), 1.0);
    double spd = FairnessMetrics.groupStatisticalParityDifference(selector, testInputs, model, output);
    assertThat(spd).isBetween(-1d, 1d);
}
Also used: FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory), SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction), Arrays(java.util.Arrays), Predicate(java.util.function.Predicate), Feature(org.kie.kogito.explainability.model.Feature), Prediction(org.kie.kogito.explainability.model.Prediction), BiFunction(java.util.function.BiFunction), Dataset(org.kie.kogito.explainability.model.Dataset), Value(org.kie.kogito.explainability.model.Value), Function(java.util.function.Function), Collectors(java.util.stream.Collectors), StringUtils(org.apache.commons.lang3.StringUtils), Type(org.kie.kogito.explainability.model.Type), PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider), ArrayList(java.util.ArrayList), ExecutionException(java.util.concurrent.ExecutionException), Test(org.junit.jupiter.api.Test), PredictionInput(org.kie.kogito.explainability.model.PredictionInput), List(java.util.List), TestUtils(org.kie.kogito.explainability.TestUtils), Locale(java.util.Locale), Output(org.kie.kogito.explainability.model.Output), AssertionsForClassTypes.assertThat(org.assertj.core.api.AssertionsForClassTypes.assertThat), PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput)
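The test above only checks that the statistical parity difference (SPD) lies in [-1, 1]. That bound follows from the usual definition: SPD is the rate of the favorable outcome in the selected group minus the rate in all other samples, and both rates are probabilities. The sketch below assumes that definition and treats `spam = false` as the favorable outcome, as the test's `Output` suggests; `StatisticalParitySketch`, its methods and the money-keyword classifier are hypothetical, and Kogito's `FairnessMetrics.groupStatisticalParityDifference` may differ in details such as how it handles an empty group.

```java
import java.util.List;
import java.util.function.Predicate;

// Hypothetical, library-free sketch of group statistical parity difference (SPD).
final class StatisticalParitySketch {

    // Fraction of samples for which the model produces the favorable outcome.
    static <T> double favorableRate(List<T> samples, Predicate<T> isFavorable) {
        if (samples.isEmpty()) {
            // Assumption: an empty group is treated as an error in this sketch.
            throw new IllegalArgumentException("group is empty");
        }
        long favorable = samples.stream().filter(isFavorable).count();
        return (double) favorable / samples.size();
    }

    // SPD = P(favorable | selected group) - P(favorable | everyone else).
    // Both terms lie in [0, 1], so the result lies in [-1, 1].
    static <T> double groupSPD(List<T> samples, Predicate<T> groupSelector, Predicate<T> isFavorable) {
        List<T> selected = samples.stream().filter(groupSelector).toList();
        List<T> others = samples.stream().filter(groupSelector.negate()).toList();
        return favorableRate(selected, isFavorable) - favorableRate(others, isFavorable);
    }

    public static void main(String[] args) {
        List<String> texts = List.of(
            "please give me some money",
            "we got money matter! please reply",
            "you just requested to change your password",
            "you are the lucky winner of a 100k $ prize");
        // Hypothetical stand-in classifier: "not spam" (favorable) unless the text mentions money.
        Predicate<String> notSpam = t -> !t.contains("money");
        double spd = groupSPD(texts, t -> t.contains("please"), notSpam);
        System.out.println("SPD = " + spd);
    }
}
```

Here both "please" texts mention money, so the selected group never gets the favorable label while the others always do, which yields the extreme value -1.0.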

Example 78 with Value

Use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.

From class MatrixUtilsExtensionsTest, method testPOCreation.

// test creation of a vector from a single PredictionOutput
@Test
void testPOCreation() {
    // use the first row of mat3X5 as the values of the prediction outputs
    List<Output> os = new ArrayList<>();
    for (int j = 0; j < 5; j++) {
        Value v = new Value(mat3X5[0][j]);
        os.add(new Output("o", Type.NUMBER, v, 0.0));
    }
    PredictionOutput po = new PredictionOutput(os);
    RealVector converted = MatrixUtilsExtensions.vectorFromPredictionOutput(po);
    assertArrayEquals(mat3X5[0], converted.toArray());
}
Also used: PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput), RealVector(org.apache.commons.math3.linear.RealVector), Output(org.kie.kogito.explainability.model.Output), ArrayList(java.util.ArrayList), Value(org.kie.kogito.explainability.model.Value), Test(org.junit.jupiter.api.Test)
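`testPOCreation` checks a round trip: one row of `mat3X5` becomes a list of numeric `Output`s, and `MatrixUtilsExtensions.vectorFromPredictionOutput` turns them back into a vector with the same entries in the same order. A library-free sketch of that conversion follows, assuming each output contributes exactly its numeric value; `OutputVectorSketch`, `NumericOutput` and `toVector` are hypothetical, and a plain `double[]` replaces Commons Math's `RealVector` (the array could be wrapped with `new ArrayRealVector(double[])` if a `RealVector` is needed).

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical, library-free sketch: numeric outputs converted to a vector.
final class OutputVectorSketch {

    // Stand-in for a Kogito Output: a name plus a numeric value (type and score omitted).
    record NumericOutput(String name, double value) {}

    // Copies each output's value into a double[], preserving output order.
    static double[] toVector(List<NumericOutput> outputs) {
        double[] vector = new double[outputs.size()];
        for (int i = 0; i < outputs.size(); i++) {
            vector[i] = outputs.get(i).value();
        }
        return vector;
    }

    public static void main(String[] args) {
        // Hypothetical values; the contents of mat3X5 are not shown in the source.
        double[] row = {1.0, 2.0, 3.0, 4.0, 5.0};
        List<NumericOutput> outputs = new ArrayList<>();
        for (double v : row) {
            outputs.add(new NumericOutput("o", v));
        }
        // The round trip should preserve both values and order.
        System.out.println(Arrays.equals(row, toVector(outputs))); // prints true
    }
}
```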

Example 79 with Value

Use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.

From class PrequalificationDmnCounterfactualExplainerTest, method testValidCounterfactual.

@Test
void testValidCounterfactual() throws ExecutionException, InterruptedException, TimeoutException {
    PredictionProvider model = getModel();
    final List<Output> goal = List.of(new Output("Qualified?", Type.BOOLEAN, new Value(true), 0.0d));
    final TerminationConfig terminationConfig = new TerminationConfig().withScoreCalculationCountLimit(steps);
    final SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(terminationConfig).build();
    solverConfig.setRandomSeed(randomSeed);
    solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
    CounterfactualConfig config = new CounterfactualConfig().withGoalThreshold(0.1);
    config.withSolverConfig(solverConfig);
    final CounterfactualExplainer explainer = new CounterfactualExplainer(config);
    PredictionInput input = getTestInputVariable();
    PredictionOutput output = new PredictionOutput(goal);
    // test model
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(getTestInputFixed())).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    final Output predictionOutput = predictionOutputs.get(0).getOutputs().get(0);
    assertEquals("Qualified?", predictionOutput.getName());
    assertFalse((Boolean) predictionOutput.getValue().getUnderlyingObject());
    Prediction prediction = new CounterfactualPrediction(input, output, null, UUID.randomUUID(), null);
    CounterfactualResult counterfactualResult = explainer.explainAsync(prediction, model).get();
    List<Feature> cfFeatures = counterfactualResult.getEntities().stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList());
    List<Feature> unflattened = CompositeFeatureUtils.unflattenFeatures(cfFeatures, input.getFeatures());
    List<PredictionOutput> outputs = model.predictAsync(List.of(new PredictionInput(unflattened))).get();
    assertTrue(counterfactualResult.isValid());
    final Output decideOutput = outputs.get(0).getOutputs().get(0);
    assertEquals("Qualified?", decideOutput.getName());
    assertTrue((Boolean) decideOutput.getValue().getUnderlyingObject());
}
Also used: PredictionInput(org.kie.kogito.explainability.model.PredictionInput), Prediction(org.kie.kogito.explainability.model.Prediction), CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction), PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider), Feature(org.kie.kogito.explainability.model.Feature), CounterfactualResult(org.kie.kogito.explainability.local.counterfactual.CounterfactualResult), TerminationConfig(org.optaplanner.core.config.solver.termination.TerminationConfig), PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput), Output(org.kie.kogito.explainability.model.Output), CounterfactualConfig(org.kie.kogito.explainability.local.counterfactual.CounterfactualConfig), Value(org.kie.kogito.explainability.model.Value), CounterfactualExplainer(org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer), SolverConfig(org.optaplanner.core.config.solver.SolverConfig), Test(org.junit.jupiter.api.Test)
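The counterfactual test first confirms that the fixed input is not qualified, then asks the explainer for feature values that make `Qualified?` true and re-runs the model on them to check the result. The sketch below illustrates that contract with a deliberately simple brute-force search, swapped in for the OptaPlanner-based solver the explainer is configured with; the `CounterfactualSketch` class, the `Applicant` record, the 40% debt-to-income rule and the income-only search are all hypothetical and unrelated to the actual prequalification DMN model.

```java
import java.util.Optional;

// Hypothetical sketch of a "valid counterfactual": changed inputs whose prediction matches the goal.
final class CounterfactualSketch {

    record Applicant(int monthlyIncome, int monthlyDebt) {}

    // Hypothetical stand-in for the model: qualified when debt is at most 40% of income.
    static boolean qualified(Applicant a) {
        return a.monthlyDebt() * 10 <= a.monthlyIncome() * 4;
    }

    // Brute-force search (not the OptaPlanner solver): raises income in fixed steps while
    // keeping debt fixed, and returns the first candidate whose prediction equals the goal.
    static Optional<Applicant> findCounterfactual(Applicant original, boolean goal, int maxIncome, int step) {
        if (step <= 0) {
            throw new IllegalArgumentException("step must be positive");
        }
        for (int income = original.monthlyIncome(); income <= maxIncome; income += step) {
            Applicant candidate = new Applicant(income, original.monthlyDebt());
            if (qualified(candidate) == goal) {
                return Optional.of(candidate);
            }
        }
        return Optional.empty();
    }

    public static void main(String[] args) {
        Applicant original = new Applicant(3000, 1500);
        System.out.println("original qualified? " + qualified(original));
        // Validity check mirrors the test: re-run the model on the counterfactual.
        findCounterfactual(original, true, 10000, 100)
            .ifPresent(cf -> System.out.println("counterfactual: " + cf + ", qualified? " + qualified(cf)));
    }
}
```

Unlike this exhaustive loop, the real explainer stops after the configured score-calculation budget (`withScoreCalculationCountLimit(steps)`), which is why the test checks `isValid()` and re-predicts instead of assuming success.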

Example 80 with Value

Use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.

From class ConversionUtils, method toTypedValue.

static TypedValue toTypedValue(Feature feature) {
    String name = feature.getName();
    Type type = feature.getType();
    Value value = feature.getValue();
    return toTypedValue(name, type, value);
}
Also used: Type(org.kie.kogito.explainability.model.Type), CounterfactualSearchDomainValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainValue), Value(org.kie.kogito.explainability.model.Value), HasNameValue(org.kie.kogito.explainability.api.HasNameValue), NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue), TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue), UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue), CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue)
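`toTypedValue(Feature)` only unpacks the feature's name, type and value and delegates to a three-argument overload that is not shown here, so code that already holds those three pieces can call the overload directly. A minimal sketch of that delegation pattern follows; `TypedValueSketch`, `SimpleFeature`, `SimpleTypedValue`, `Kind` and the type-name mapping are hypothetical. Judging by the imports, the real overload works with Kogito tracing types such as `UnitValue` and `CollectionValue`, which this sketch does not model.

```java
// Hypothetical sketch of the overload-delegation pattern used by toTypedValue(Feature).
final class TypedValueSketch {

    enum Kind { NUMBER, BOOLEAN, TEXT }

    // Stand-in for a Kogito Feature.
    record SimpleFeature(String name, Kind kind, Object value) {}

    // Stand-in for a tracing TypedValue.
    record SimpleTypedValue(String name, String typeRef, Object value) {}

    // Convenience overload: unpack the feature, then delegate.
    static SimpleTypedValue toTypedValue(SimpleFeature feature) {
        return toTypedValue(feature.name(), feature.kind(), feature.value());
    }

    // Core overload: all conversion logic lives here, so callers holding a
    // (name, type, value) triple can reuse it without building a feature first.
    static SimpleTypedValue toTypedValue(String name, Kind kind, Object value) {
        // Hypothetical type-name mapping, chosen only for illustration.
        String typeRef = switch (kind) {
            case NUMBER -> "number";
            case BOOLEAN -> "boolean";
            case TEXT -> "string";
        };
        return new SimpleTypedValue(name, typeRef, value);
    }

    public static void main(String[] args) {
        System.out.println(toTypedValue(new SimpleFeature("age", Kind.NUMBER, 42)));
    }
}
```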

Aggregations

Value (org.kie.kogito.explainability.model.Value): 80
Feature (org.kie.kogito.explainability.model.Feature): 69
Output (org.kie.kogito.explainability.model.Output): 59
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 54
PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 49
ArrayList (java.util.ArrayList): 42
PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 42
LinkedList (java.util.LinkedList): 36
Type (org.kie.kogito.explainability.model.Type): 36
Test (org.junit.jupiter.api.Test): 35
List (java.util.List): 33
Prediction (org.kie.kogito.explainability.model.Prediction): 33
Random (java.util.Random): 31
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 23
Arrays (java.util.Arrays): 16
Map (java.util.Map): 16
Optional (java.util.Optional): 16
CounterfactualEntity (org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity): 16
FeatureFactory (org.kie.kogito.explainability.model.FeatureFactory): 16
Collectors (java.util.stream.Collectors): 15