Usage of org.kie.kogito.explainability.model.PartialDependenceGraph in the kogito-apps project by kiegroup.
Class PartialDependencePlotExplainerTest, method testTextClassifier.
/**
 * Builds predictions for a handful of short texts with the dummy text classifier and checks
 * that the PDP explainer returns at least one partial dependence graph for them.
 *
 * @param seed value supplied by {@code @ValueSource}; kept so the parameterized signature is
 *        unchanged. NOTE(review): nothing in this method consumes the seed (the previously
 *        created {@code Random} was never handed to the model or the explainer), so all
 *        invocations currently exercise the same path - confirm whether the explainer
 *        should be seeded.
 * @throws Exception if obtaining a prediction or producing the explanation fails
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testTextClassifier(int seed) throws Exception {
    PartialDependencePlotExplainer partialDependencePlotExplainer = new PartialDependencePlotExplainer();
    PredictionProvider model = TestUtils.getDummyTextClassifier();
    List<String> texts = List.of("we want your money", "please reply quickly", "you are the lucky winner", "huge donation for you!", "bitcoin for you");
    // one prediction is collected per text, so size the list accordingly
    Collection<Prediction> predictions = new ArrayList<>(texts.size());
    for (String text : texts) {
        List<Feature> features = new ArrayList<>();
        features.add(FeatureFactory.newFulltextFeature("text", text));
        PredictionInput predictionInput = new PredictionInput(features);
        // a single input is submitted, so the single resulting output is taken
        PredictionOutput predictionOutput = model.predictAsync(List.of(predictionInput)).get().get(0);
        predictions.add(new SimplePrediction(predictionInput, predictionOutput));
    }
    List<PartialDependenceGraph> pdps = partialDependencePlotExplainer.explainFromPredictions(model, predictions);
    assertThat(pdps).isNotEmpty();
}
Usage of org.kie.kogito.explainability.model.PartialDependenceGraph in the kogito-apps project by kiegroup.
Class PartialDependencePlotExplainer, method getPartialDependenceGraph.
/**
 * Builds the partial dependence graph of one feature against one model output.
 *
 * <p>For every candidate value of the feature under analysis a batch of inputs is prepared
 * from the training data, the model is queried, and the output found at {@code outputIndex}
 * is recorded in the value counts for that candidate. The recorded counts are then collapsed
 * into the y values of the graph.
 *
 * @param model the model to query
 * @param trainingData inputs used as the basis for the prepared prediction requests
 * @param xsValues the x values of the resulting graph
 * @param featureXSvalues the feature under analysis, one instance per candidate value
 * @param outputIndex position of the output of interest within each prediction output
 * @return the graph for the analysed feature and the selected output
 * @throws IllegalArgumentException if no output was observed (e.g. no candidate values or
 *         no prediction outputs)
 * @throws InterruptedException declared for the model invocation
 * @throws ExecutionException declared for the model invocation
 * @throws TimeoutException declared for the model invocation
 */
private PartialDependenceGraph getPartialDependenceGraph(PredictionProvider model, List<PredictionInput> trainingData, List<Value> xsValues, List<Feature> featureXSvalues, int outputIndex) throws InterruptedException, ExecutionException, TimeoutException {
    // one entry of value counts is expected per candidate value of the analysed feature
    List<Map<Value, Long>> countsPerXs = new ArrayList<>(featureXSvalues.size());
    Feature pdpFeature = null;
    Output pdpOutput = null;

    int xsIndex = 0;
    for (Feature xsFeature : featureXSvalues) {
        if (pdpFeature == null) {
            // value-less copy of the analysed feature, used to label the generated PDP
            pdpFeature = FeatureFactory.copyOf(xsFeature, new Value(null));
        }
        // prediction requests are batched per candidate value of the analysed feature
        List<PredictionInput> batch = prepareInputs(xsFeature, trainingData);
        for (PredictionOutput batchResult : getOutputs(model, batch)) {
            Output observed = batchResult.getOutputs().get(outputIndex);
            if (pdpOutput == null) {
                // name and type of the first observed output describe the PDP's decision
                pdpOutput = new Output(observed.getName(), observed.getType());
            }
            updateValueCounts(countsPerXs, xsIndex, observed);
        }
        xsIndex++;
    }

    if (pdpOutput == null) {
        throw new IllegalArgumentException("cannot produce PDP for null decision");
    }
    List<Value> yValues = collapseMarginalImpacts(countsPerXs, pdpOutput.getType());
    return new PartialDependenceGraph(pdpFeature, pdpOutput, xsValues, yValues);
}
Usage of org.kie.kogito.explainability.model.PartialDependenceGraph in the kogito-apps project by kiegroup.
Class PartialDependencePlotExplainer, method explainFromDataDistribution.
/**
 * Produces one partial dependence graph per (feature, output) pair from a data distribution.
 *
 * <p>Training data is sampled once from the whole distribution. For each feature, candidate
 * values are sampled from that feature's own distribution, ordered, de-duplicated and turned
 * into features, and a graph is generated for every output index.
 *
 * @param model the model to explain
 * @param outputSize number of outputs; a graph is generated for each index in
 *        {@code [0, outputSize)}
 * @param dataDistribution source of the training data and of the per-feature samples
 * @return the generated graphs, grouped by feature and, within a feature, by output index
 * @throws InterruptedException declared for the model invocation
 * @throws ExecutionException declared for the model invocation
 * @throws TimeoutException declared for the model invocation
 */
private List<PartialDependenceGraph> explainFromDataDistribution(PredictionProvider model, int outputSize, DataDistribution dataDistribution) throws InterruptedException, ExecutionException, TimeoutException {
    long start = System.currentTimeMillis();
    List<PartialDependenceGraph> pdps = new ArrayList<>();
    List<FeatureDistribution> featureDistributions = dataDistribution.asFeatureDistributions();
    // fetch entire data distributions for all features
    List<PredictionInput> trainingData = dataDistribution.sample(config.getSeriesLength());

    // Order candidate values numerically; values whose numbers compare equal (including the
    // case where Value#asNumber is NaN for both) fall back to alphanumeric order. This single
    // comparator yields the same order as a stable sort by string followed by a stable sort by
    // number, without sorting twice or allocating a comparator on every comparison.
    Comparator<Value> byNumber = Comparator.comparingDouble(Value::asNumber);
    Comparator<Value> xsOrder = byNumber.thenComparing(Value::asString);

    // create a PDP for each feature
    for (FeatureDistribution featureDistribution : featureDistributions) {
        // generate (further) samples for the feature under analysis
        // TBD: maybe just reuse trainingData
        List<Value> xsValues = featureDistribution.sample(config.getSeriesLength()).stream()
                .sorted(xsOrder)
                .distinct()
                .collect(Collectors.toList());
        // transform sampled Values into Features
        List<Feature> featureXSvalues = xsValues.stream()
                .map(v -> FeatureFactory.copyOf(featureDistribution.getFeature(), v))
                .collect(Collectors.toList());
        // create a PDP for each feature and each output
        for (int outputIndex = 0; outputIndex < outputSize; outputIndex++) {
            pdps.add(getPartialDependenceGraph(model, trainingData, xsValues, featureXSvalues, outputIndex));
        }
    }
    long end = System.currentTimeMillis();
    LOGGER.debug("explanation time: {}ms", (end - start));
    return pdps;
}
Usage of org.kie.kogito.explainability.model.PartialDependenceGraph in the kogito-apps project by kiegroup.
Class TrafficViolationDmnPDPExplainerTest, method testTrafficViolationDMNExplanation.
/**
 * Loads the Traffic Violation DMN model, runs it on randomly generated inputs and checks
 * that the PDP explainer produces the expected number of partial dependence graphs.
 */
@Test
void testTrafficViolationDMNExplanation() throws ExecutionException, InterruptedException, TimeoutException {
    final String namespace = "https://github.com/kiegroup/drools/kie-dmn/_A4BCA8B8-CF08-433F-93B2-A2598F19ECFF";
    final String modelName = "Traffic Violation";

    // the DMN resource is expected to define exactly one model
    InputStreamReader dmnSource = new InputStreamReader(getClass().getResourceAsStream("/dmn/TrafficViolation.dmn"));
    DMNRuntime runtime = DMNKogito.createGenericDMNRuntime(dmnSource);
    assertEquals(1, runtime.getModels().size());

    PredictionProvider model = new DecisionModelWrapper(new DmnDecisionModel(runtime, namespace, modelName));

    // evaluate the model on the random inputs, bounded by the configured async timeout
    List<PredictionInput> inputs = DmnTestUtils.randomTrafficViolationInputs();
    List<PredictionOutput> outputs = model.predictAsync(inputs)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());

    // pair each output with the input found at the same position
    List<Prediction> predictions = new ArrayList<>();
    int position = 0;
    for (PredictionOutput output : outputs) {
        predictions.add(new SimplePrediction(inputs.get(position), output));
        position++;
    }

    PartialDependencePlotExplainer explainer = new PartialDependencePlotExplainer();
    List<PartialDependenceGraph> pdps = explainer.explainFromPredictions(model, predictions);

    AssertionsForClassTypes.assertThat(pdps).isNotNull();
    Assertions.assertThat(pdps).hasSize(8);
}
Usage of org.kie.kogito.explainability.model.PartialDependenceGraph in the kogito-apps project by kiegroup.
Class PrequalificationDmnPDPExplainerTest, method testPrequalificationDMNExplanation.
/**
 * Loads the Prequalification DMN model, runs it on randomly generated inputs and checks
 * that the PDP explainer produces the expected number of partial dependence graphs.
 */
@Test
void testPrequalificationDMNExplanation() throws ExecutionException, InterruptedException, TimeoutException {
    final String namespace = "http://www.trisotech.com/definitions/_f31e1f8e-d4ce-4a3a-ac3b-747efa6b3401";
    final String modelName = "Prequalification";

    // the DMN resource is expected to define exactly one model
    InputStreamReader dmnSource = new InputStreamReader(getClass().getResourceAsStream("/dmn/Prequalification-1.dmn"));
    DMNRuntime runtime = DMNKogito.createGenericDMNRuntime(dmnSource);
    assertEquals(1, runtime.getModels().size());

    PredictionProvider model = new DecisionModelWrapper(new DmnDecisionModel(runtime, namespace, modelName));

    // evaluate the model on the random inputs, bounded by the configured async timeout
    List<PredictionInput> inputs = DmnTestUtils.randomPrequalificationInputs();
    List<PredictionOutput> outputs = model.predictAsync(inputs)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());

    // pair each output with the input found at the same position
    List<Prediction> predictions = new ArrayList<>();
    int position = 0;
    for (PredictionOutput output : outputs) {
        predictions.add(new SimplePrediction(inputs.get(position), output));
        position++;
    }

    PartialDependencePlotExplainer explainer = new PartialDependencePlotExplainer();
    List<PartialDependenceGraph> pdps = explainer.explainFromPredictions(model, predictions);

    AssertionsForClassTypes.assertThat(pdps).isNotNull();
    Assertions.assertThat(pdps).hasSize(25);
}
Aggregations