Usage of `org.kie.kogito.explainability.model.Value` in project kogito-apps by kiegroup.
Class `FairnessMetricsTest`, method `testGroupDIRTextClassifier`.
/**
 * Computes the group disparate impact ratio of the dummy text classifier, taking as the
 * group of interest every input whose textified form contains "please", and checks that
 * the resulting ratio is strictly positive.
 */
@Test
void testGroupDIRTextClassifier() throws ExecutionException, InterruptedException {
    List<PredictionInput> inputs = getTestInputs();
    PredictionProvider classifier = TestUtils.getDummyTextClassifier();

    // Group membership: the input's text rendering mentions "please".
    Predicate<PredictionInput> mentionsPlease =
            candidate -> DataUtils.textify(candidate).contains("please");

    // Reference output the metric compares group outcomes against: spam == false.
    Output notSpam = new Output("spam", Type.BOOLEAN, new Value(false), 1.0);

    double ratio = FairnessMetrics.groupDisparateImpactRatio(mentionsPlease, inputs, classifier, notSpam);
    assertThat(ratio).isPositive();
}
Usage of `org.kie.kogito.explainability.model.Value` in project kogito-apps by kiegroup.
Class `MatrixUtilsExtensionsTest`, method `testPOListCreation`.
// test creation of matrix from list of PredictionOutputs
@Test
void testPOListCreation() {
// use the mat 3x5 as our list of prediction outputs
List<PredictionOutput> ps = new ArrayList<>();
for (int i = 0; i < 3; i++) {
List<Output> os = new ArrayList<>();
for (int j = 0; j < 5; j++) {
Value v = new Value(mat3X5[i][j]);
os.add(new Output("o", Type.NUMBER, v, 0.0));
}
ps.add(new PredictionOutput(os));
}
RealMatrix converted = MatrixUtilsExtensions.matrixFromPredictionOutput(ps);
assertArrayEquals(mat3X5, converted.getData());
}
Usage of `org.kie.kogito.explainability.model.Value` in project kogito-apps by kiegroup.
Class `DecisionModelWrapper`, method `predictAsync`.
/**
 * Evaluates the wrapped decision model once per input and turns every decision result
 * (except those named in {@code skippedDecisions}) into an {@link Output} with score 1.
 *
 * <p>All evaluation is done before returning: the result is handed back through
 * {@code completedFuture}.
 *
 * @param inputs prediction inputs, each converted to a DMN context via {@code toMap}
 * @return one {@link PredictionOutput} per input, in input order
 */
@Override
public CompletableFuture<List<PredictionOutput>> predictAsync(List<PredictionInput> inputs) {
    List<PredictionOutput> results = new LinkedList<>();
    for (PredictionInput predictionInput : inputs) {
        Map<String, Object> variables = toMap(predictionInput.getFeatures());
        final DMNContext dmnContext = decisionModel.newContext(variables);
        DMNResult evaluation = decisionModel.evaluateAll(dmnContext);

        List<Output> decisionOutputs = new LinkedList<>();
        for (DMNDecisionResult decision : evaluation.getDecisionResults()) {
            String name = decision.getDecisionName();
            if (skippedDecisions.contains(name)) {
                // Skipped decisions are not exposed as outputs.
                continue;
            }
            Object raw = decision.getResult();
            Value value = new Value(raw);

            // null and String map to TEXT, Boolean to BOOLEAN, everything else to NUMBER.
            // NOTE(review): non-numeric results (e.g. dates, lists, contexts) also land in
            // NUMBER here — confirm that is intended.
            Type type = Type.NUMBER;
            if (raw == null || raw instanceof String) {
                type = Type.TEXT;
            } else if (raw instanceof Boolean) {
                type = Type.BOOLEAN;
            }

            decisionOutputs.add(new Output(name, type, value, 1d));
        }
        results.add(new PredictionOutput(decisionOutputs));
    }
    return completedFuture(results);
}
Usage of `org.kie.kogito.explainability.model.Value` in project kogito-apps by kiegroup.
Class `ShapResultsTest`, method `buildShapResults`.
/**
 * Builds a synthetic {@link ShapResults} fixture.
 *
 * <p>Output {@code o<out>} gets one numeric feature {@code f<feat>} per feature index, with
 * importance {@code out * feat * scalar1}. The null vector has {@code nOutputs} entries,
 * each equal to {@code scalar2}.
 *
 * @param nOutputs  number of outputs (saliencies) to create
 * @param nFeatures number of features per saliency
 * @param scalar1   multiplier applied to every feature importance
 * @param scalar2   constant value of each null-vector entry
 * @return the assembled fixture
 */
ShapResults buildShapResults(int nOutputs, int nFeatures, int scalar1, int scalar2) {
    Saliency[] saliencies = new Saliency[nOutputs];
    for (int out = 0; out < nOutputs; out++) {
        List<FeatureImportance> importances = new ArrayList<>();
        for (int feat = 0; feat < nFeatures; feat++) {
            Feature feature = new Feature("f" + feat, Type.NUMBER, new Value(feat));
            importances.add(new FeatureImportance(feature, out * feat * scalar1));
        }
        Output output = new Output("o" + out, Type.NUMBER, new Value(out), 1.0);
        saliencies[out] = new Saliency(output, importances);
    }

    // Start from a zero vector and shift every entry by scalar2.
    RealVector fnull = MatrixUtils.createRealVector(new double[nOutputs]);
    fnull.mapAddToSelf(scalar2);
    return new ShapResults(saliencies, fnull);
}
Usage of `org.kie.kogito.explainability.model.Value` in project kogito-apps by kiegroup.
Class `DataUtilsTest`, method `toCSV`.
/**
 * Checks that a small {@link PartialDependenceGraph} (mocked feature and output, three
 * x/y value pairs) can be written with {@code DataUtils.toCSV} to
 * {@code target/test-pdp.csv} without throwing.
 */
@Test
void toCSV() {
    Feature feature = mock(Feature.class);
    when(feature.getName()).thenReturn("feature-1");
    Output output = mock(Output.class);
    when(output.getName()).thenReturn("decision-1");

    // Paired x/y samples: (1,4), (2,5), (3,4).
    int[] xs = {1, 2, 3};
    int[] ys = {4, 5, 4};
    List<Value> x = new ArrayList<>();
    List<Value> y = new ArrayList<>();
    for (int k = 0; k < xs.length; k++) {
        x.add(new Value(xs[k]));
        y.add(new Value(ys[k]));
    }

    PartialDependenceGraph graph = new PartialDependenceGraph(feature, output, x, y);
    assertDoesNotThrow(() -> DataUtils.toCSV(graph, Paths.get("target/test-pdp.csv")));
}
Aggregations