Search in sources :

Example 1 with Dataset

use of org.kie.kogito.explainability.model.Dataset in project kogito-apps by kiegroup.

the class FairnessMetrics method groupAverageOddsDifference.

/**
 * Calculate average odds difference (AOD) between the unprivileged and the privileged group.
 * <p>
 * The dataset is split with {@code inputSelector} (privileged) and its negation (unprivileged), the model is
 * run on each split and the resulting {@code tp}/{@code tn}/{@code fp}/{@code fn} counts are combined as
 * {@code (TPR_u - TPR_p) / 2 + (FPR_u - FPR_p) / 2}, where {@code TPR = tp / (tp + fn)} and
 * {@code FPR = fp / (fp + tn)}.
 * <p>
 * The same small epsilon is added to every denominator, so a group with no samples in a given
 * denominator contributes a rate of {@code 0} instead of turning the whole result into {@code NaN}.
 *
 * @param inputSelector selector for privileged group
 * @param outputSelector selector for favorable label
 * @param dataset dataset used to evaluate AOD
 * @param model model to be evaluated fairness-wise
 * @return average odds difference value
 * @throws ExecutionException if any error occurs during model prediction
 * @throws InterruptedException if timeout or other interruption issues occur during model prediction
 */
public static double groupAverageOddsDifference(Predicate<PredictionInput> inputSelector, Predicate<PredictionOutput> outputSelector, Dataset dataset, PredictionProvider model) throws ExecutionException, InterruptedException {
    // Guards every ratio below against a zero denominator (0/0 would otherwise yield NaN).
    final double epsilon = 1e-10;

    Dataset privileged = dataset.filterByInput(inputSelector);
    Map<String, Integer> privilegedCounts = countMatchingOutputSelector(privileged, model.predictAsync(privileged.getInputs()).get(), outputSelector);
    Dataset unprivileged = dataset.filterByInput(inputSelector.negate());
    Map<String, Integer> unprivilegedCounts = countMatchingOutputSelector(unprivileged, model.predictAsync(unprivileged.getInputs()).get(), outputSelector);

    double utp = unprivilegedCounts.get("tp");
    double utn = unprivilegedCounts.get("tn");
    double ufp = unprivilegedCounts.get("fp");
    double ufn = unprivilegedCounts.get("fn");
    double ptp = privilegedCounts.get("tp");
    double ptn = privilegedCounts.get("tn");
    double pfp = privilegedCounts.get("fp");
    double pfn = privilegedCounts.get("fn");

    // Difference in true positive rate: tp / (tp + fn), unprivileged minus privileged.
    double truePositiveRateDifference = utp / (utp + ufn + epsilon) - ptp / (ptp + pfn + epsilon);
    // Difference in false positive rate: fp / (fp + tn), unprivileged minus privileged.
    double falsePositiveRateDifference = ufp / (ufp + utn + epsilon) - pfp / (pfp + ptn + epsilon);
    return truePositiveRateDifference / 2d + falsePositiveRateDifference / 2d;
}
Also used : Dataset(org.kie.kogito.explainability.model.Dataset)

Example 2 with Dataset

use of org.kie.kogito.explainability.model.Dataset in project kogito-apps by kiegroup.

the class FairnessMetrics method groupAveragePredictiveValueDifference.

/**
 * Calculate average predictive value difference (APVD) between the unprivileged and the privileged group.
 * <p>
 * The dataset is split with {@code inputSelector} (privileged) and its negation (unprivileged), the model is
 * run on each split and the resulting {@code tp}/{@code tn}/{@code fp}/{@code fn} counts are combined as
 * {@code (PPV_u - PPV_p) / 2 + (FOR_u - FOR_p) / 2}, where {@code PPV = tp / (tp + fp)} and
 * {@code FOR = fn / (fn + tn)}.
 * <p>
 * The same small epsilon is added to every denominator, so a group with no samples in a given
 * denominator contributes a value of {@code 0} instead of turning the whole result into {@code NaN}.
 *
 * @param inputSelector selector for privileged group
 * @param outputSelector selector for favorable label
 * @param dataset dataset used to evaluate APVD
 * @param model model to be evaluated fairness-wise
 * @return average predictive value difference
 * @throws ExecutionException if any error occurs during model prediction
 * @throws InterruptedException if timeout or other interruption issues occur during model prediction
 */
public static double groupAveragePredictiveValueDifference(Predicate<PredictionInput> inputSelector, Predicate<PredictionOutput> outputSelector, Dataset dataset, PredictionProvider model) throws ExecutionException, InterruptedException {
    // Guards every ratio below against a zero denominator (0/0 would otherwise yield NaN).
    final double epsilon = 1e-10;

    Dataset privileged = dataset.filterByInput(inputSelector);
    Map<String, Integer> privilegedCounts = countMatchingOutputSelector(privileged, model.predictAsync(privileged.getInputs()).get(), outputSelector);
    double ptp = privilegedCounts.get("tp");
    double ptn = privilegedCounts.get("tn");
    double pfp = privilegedCounts.get("fp");
    double pfn = privilegedCounts.get("fn");

    Dataset unprivileged = dataset.filterByInput(inputSelector.negate());
    Map<String, Integer> unprivilegedCounts = countMatchingOutputSelector(unprivileged, model.predictAsync(unprivileged.getInputs()).get(), outputSelector);
    double utp = unprivilegedCounts.get("tp");
    double utn = unprivilegedCounts.get("tn");
    double ufp = unprivilegedCounts.get("fp");
    double ufn = unprivilegedCounts.get("fn");

    // Difference in positive predictive value: tp / (tp + fp), unprivileged minus privileged.
    double positivePredictiveValueDifference = utp / (utp + ufp + epsilon) - ptp / (ptp + pfp + epsilon);
    // Difference in false omission rate: fn / (fn + tn), unprivileged minus privileged.
    double falseOmissionRateDifference = ufn / (ufn + utn + epsilon) - pfn / (pfn + ptn + epsilon);
    return positivePredictiveValueDifference / 2d + falseOmissionRateDifference / 2d;
}
Also used : Dataset(org.kie.kogito.explainability.model.Dataset)

Example 3 with Dataset

use of org.kie.kogito.explainability.model.Dataset in project kogito-apps by kiegroup.

the class FairnessMetricsTest method testGroupAPVDTextClassifier.

/**
 * Runs the group average predictive value difference metric against the dummy text classifier
 * and checks that the result lies within [-1, 1].
 */
@Test
void testGroupAPVDTextClassifier() throws ExecutionException, InterruptedException {
    Dataset testDataset = new Dataset(getTestData());
    PredictionProvider classifier = TestUtils.getDummyTextClassifier();

    // Input selector: inputs whose textified form contains "please".
    Predicate<PredictionInput> containsPlease = input -> DataUtils.textify(input).contains("please");
    // Output selector: the "spam" output has a numeric value equal to 0.
    Predicate<PredictionOutput> spamIsZero = output -> output.getByName("spam").get().getValue().asNumber() == 0;

    double difference = FairnessMetrics.groupAveragePredictiveValueDifference(containsPlease, spamIsZero, testDataset, classifier);

    assertThat(difference).isBetween(-1d, 1d);
}
Also used : FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) Arrays(java.util.Arrays) Predicate(java.util.function.Predicate) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) BiFunction(java.util.function.BiFunction) Dataset(org.kie.kogito.explainability.model.Dataset) Value(org.kie.kogito.explainability.model.Value) Function(java.util.function.Function) Collectors(java.util.stream.Collectors) StringUtils(org.apache.commons.lang3.StringUtils) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ArrayList(java.util.ArrayList) ExecutionException(java.util.concurrent.ExecutionException) Test(org.junit.jupiter.api.Test) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) TestUtils(org.kie.kogito.explainability.TestUtils) Locale(java.util.Locale) Output(org.kie.kogito.explainability.model.Output) AssertionsForClassTypes.assertThat(org.assertj.core.api.AssertionsForClassTypes.assertThat) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Dataset(org.kie.kogito.explainability.model.Dataset) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) Prediction(org.kie.kogito.explainability.model.Prediction) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Test(org.junit.jupiter.api.Test)

Example 4 with Dataset

use of org.kie.kogito.explainability.model.Dataset in project kogito-apps by kiegroup.

the class FairnessMetricsTest method testGroupAODTextClassifier.

/**
 * Runs the group average odds difference metric against the dummy text classifier
 * and checks that the result lies within [-1, 1].
 */
@Test
void testGroupAODTextClassifier() throws ExecutionException, InterruptedException {
    Dataset testDataset = new Dataset(getTestData());
    PredictionProvider classifier = TestUtils.getDummyTextClassifier();

    // Input selector: inputs whose textified form contains "please".
    Predicate<PredictionInput> containsPlease = input -> DataUtils.textify(input).contains("please");
    // Output selector: the "spam" output has a numeric value equal to 0.
    Predicate<PredictionOutput> spamIsZero = output -> output.getByName("spam").get().getValue().asNumber() == 0;

    double difference = FairnessMetrics.groupAverageOddsDifference(containsPlease, spamIsZero, testDataset, classifier);

    assertThat(difference).isBetween(-1d, 1d);
}
Also used : FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) Arrays(java.util.Arrays) Predicate(java.util.function.Predicate) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) BiFunction(java.util.function.BiFunction) Dataset(org.kie.kogito.explainability.model.Dataset) Value(org.kie.kogito.explainability.model.Value) Function(java.util.function.Function) Collectors(java.util.stream.Collectors) StringUtils(org.apache.commons.lang3.StringUtils) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ArrayList(java.util.ArrayList) ExecutionException(java.util.concurrent.ExecutionException) Test(org.junit.jupiter.api.Test) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) TestUtils(org.kie.kogito.explainability.TestUtils) Locale(java.util.Locale) Output(org.kie.kogito.explainability.model.Output) AssertionsForClassTypes.assertThat(org.assertj.core.api.AssertionsForClassTypes.assertThat) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Dataset(org.kie.kogito.explainability.model.Dataset) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) Prediction(org.kie.kogito.explainability.model.Prediction) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Test(org.junit.jupiter.api.Test)

Aggregations

Dataset (org.kie.kogito.explainability.model.Dataset)4 ArrayList (java.util.ArrayList)2 Arrays (java.util.Arrays)2 List (java.util.List)2 Locale (java.util.Locale)2 ExecutionException (java.util.concurrent.ExecutionException)2 BiFunction (java.util.function.BiFunction)2 Function (java.util.function.Function)2 Predicate (java.util.function.Predicate)2 Collectors (java.util.stream.Collectors)2 StringUtils (org.apache.commons.lang3.StringUtils)2 AssertionsForClassTypes.assertThat (org.assertj.core.api.AssertionsForClassTypes.assertThat)2 Test (org.junit.jupiter.api.Test)2 TestUtils (org.kie.kogito.explainability.TestUtils)2 Feature (org.kie.kogito.explainability.model.Feature)2 FeatureFactory (org.kie.kogito.explainability.model.FeatureFactory)2 Output (org.kie.kogito.explainability.model.Output)2 Prediction (org.kie.kogito.explainability.model.Prediction)2 PredictionInput (org.kie.kogito.explainability.model.PredictionInput)2 PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput)2