use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.
the class PartialDependencePlotExplainer method explainFromDataDistribution.
/**
 * Builds one partial dependence graph for every (feature, output) pair, taking the values of the
 * feature under analysis from the given data distribution.
 *
 * @param model the model passed on to {@code getPartialDependenceGraph}
 * @param outputSize the number of outputs; one graph is produced per output index for each feature
 * @param dataDistribution source of the sampled inputs and of the per-feature distributions
 * @return the graphs, grouped by feature in the order returned by
 *         {@code dataDistribution.asFeatureDistributions()} and, within a feature, by output index
 * @throws InterruptedException propagated from the calls made while building the graphs
 * @throws ExecutionException propagated from the calls made while building the graphs
 * @throws TimeoutException propagated from the calls made while building the graphs
 */
private List<PartialDependenceGraph> explainFromDataDistribution(PredictionProvider model, int outputSize, DataDistribution dataDistribution) throws InterruptedException, ExecutionException, TimeoutException {
    long start = System.currentTimeMillis();
    List<PartialDependenceGraph> pdps = new ArrayList<>();
    List<FeatureDistribution> featureDistributions = dataDistribution.asFeatureDistributions();
    // sample the whole data distribution once; the same samples are reused for every feature
    List<PredictionInput> trainingData = dataDistribution.sample(config.getSeriesLength());
    // Single ordering used for every feature: natural numeric order first, then alphanumeric
    // order for values that compare equal as numbers (which includes the case where
    // Value#asNumber is NaN on both sides). This is the same ordering the former pair of
    // consecutive sorts produced, without building a comparator on every comparison.
    Comparator<Value> byNumber = Comparator.comparingDouble(Value::asNumber);
    Comparator<Value> xsOrder = byNumber.thenComparing(Value::asString);
    // create a PDP for each feature
    for (FeatureDistribution featureDistribution : featureDistributions) {
        // generate (further) samples for the feature under analysis
        // TBD: maybe just reuse trainingData
        List<Value> xsValues = featureDistribution.sample(config.getSeriesLength()).stream()
                .sorted(xsOrder)
                .distinct()
                .collect(Collectors.toList());
        // transform sampled Values into Features
        List<Feature> featureXSvalues = xsValues.stream()
                .map(v -> FeatureFactory.copyOf(featureDistribution.getFeature(), v))
                .collect(Collectors.toList());
        // create a PDP for each feature and each output
        for (int outputIndex = 0; outputIndex < outputSize; outputIndex++) {
            PartialDependenceGraph partialDependenceGraph = getPartialDependenceGraph(model, trainingData, xsValues, featureXSvalues, outputIndex);
            pdps.add(partialDependenceGraph);
        }
    }
    long end = System.currentTimeMillis();
    LOGGER.debug("explanation time: {}ms", (end - start));
    return pdps;
}
use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.
the class PartialDependencePlotExplainer method collapseMarginalImpacts.
/**
 * Collapse value counts into marginal impacts, one {@link Value} per position.
 * For numbers ({@code Type.NUMBER.equals(type)}) each distinct value is multiplied by its count
 * and divided by the configured series length, and the resulting terms are summed.
 * For all other types the result wraps the string form of the value with the highest count;
 * on ties the first one met while iterating the map wins, and when no entry has a positive
 * count the result wraps {@code null}.
 *
 * @param valueCounts the frequency of each value at each position
 * @param type the type of the output
 * @return the marginal impacts, in the same order as {@code valueCounts}
 */
private List<Value> collapseMarginalImpacts(List<Map<Value, Long>> valueCounts, Type type) {
    if (Type.NUMBER.equals(type)) {
        // frequency-weighted sum of the observed numbers at each position
        return valueCounts.stream()
                .map(counts -> counts.entrySet().stream()
                        .mapToDouble(e -> e.getKey().asNumber() * e.getValue() / config.getSeriesLength())
                        .sum())
                .map(Value::new)
                .collect(Collectors.toList());
    }

    // non-numeric outputs: keep the most frequent value at each position
    List<Value> mostFrequent = new ArrayList<>();
    for (Map<Value, Long> counts : valueCounts) {
        String winner = null;
        long highest = 0;
        for (Map.Entry<Value, Long> candidate : counts.entrySet()) {
            long occurrences = candidate.getValue();
            if (occurrences > highest) {
                highest = occurrences;
                winner = candidate.getKey().asString();
            }
        }
        mostFrequent.add(new Value(winner));
    }
    return mostFrequent;
}
use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.
the class PmmlRegressionCategoricalLimeExplainerTest method getModel.
/**
 * Creates a prediction provider that, asynchronously, runs the categorical variables
 * regression PMML executor once per input and reports its "result" variable as a
 * single NUMBER output named "result" with score 1.
 *
 * @return the prediction provider wrapping the PMML model
 */
private PredictionProvider getModel() {
    return inputs -> CompletableFuture.supplyAsync(() -> {
        List<PredictionOutput> predictions = new ArrayList<>();
        for (PredictionInput predictionInput : inputs) {
            List<Feature> features = predictionInput.getFeatures();
            // the first two features are fed to the executor in their string form
            String firstArgument = features.get(0).getValue().asString();
            String secondArgument = features.get(1).getValue().asString();
            CategoricalVariablesRegressionExecutor executor = new CategoricalVariablesRegressionExecutor(firstArgument, secondArgument);
            PMML4Result pmmlResult = executor.execute(categoricalVariableRegressionRuntime);
            String score = pmmlResult.getResultVariables().get("result").toString();
            Output resultOutput = new Output("result", Type.NUMBER, new Value(score), 1d);
            predictions.add(new PredictionOutput(List.of(resultOutput)));
        }
        return predictions;
    });
}
use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.
the class RemotePredictionProvider method toPredictionOutput.
/**
 * Converts the JSON response of the prediction service into a {@link PredictionOutput}.
 * Only received outputs whose name appears in {@code predictionOutputs} are kept; every name in
 * {@code predictionOutputs} that is absent from the response is added as a placeholder output
 * of type UNDEFINED wrapping a null value.
 *
 * @param mainObj the JSON response; expected to contain a "result" object
 * @return the prediction output, or {@code null} when {@code mainObj} is null or has no "result" key
 */
protected PredictionOutput toPredictionOutput(JsonObject mainObj) {
    boolean hasResult = mainObj != null && mainObj.containsKey("result");
    if (!hasResult) {
        LOG.error("Malformed json {}", mainObj);
        return null;
    }

    List<Output> receivedOutputs = toOutputList(mainObj.getJsonObject("result"));
    List<String> receivedNames = receivedOutputs.stream().map(Output::getName).collect(toList());
    Map<String, TypedValue> expectedOutputs = predictionOutputs.stream().collect(Collectors.toMap(HasNameValue::getName, HasNameValue::getValue));

    // The prediction service may leave some outputs out of its response (e.g. when the generated
    // perturbed inputs don't make sense and a decision is skipped), while the explainer may throw
    // exceptions if something specified in the execution request cannot be found.
    // So: keep the received outputs that were asked for, and fill (only if needed) the gaps with
    // Output objects containing "null" values of type UNDEFINED.
    Stream<Output> requestedAndReceived = receivedOutputs.stream()
            .filter(output -> expectedOutputs.containsKey(output.getName()));
    Stream<Output> placeholdersForMissing = expectedOutputs.keySet().stream()
            .filter(name -> !receivedNames.contains(name))
            .map(name -> new Output(name, Type.UNDEFINED, new Value(null), 1d));

    List<Output> outputs = Stream.concat(requestedAndReceived, placeholdersForMissing).collect(toList());
    return new PredictionOutput(outputs);
}
use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.
the class PmmlRegressionLimeExplainerTest method getModel.
/**
 * Creates a prediction provider that, asynchronously, runs the logistic regression iris PMML
 * executor once per input. Each prediction holds a single TEXT output named "species" whose
 * value is the "Species" result variable and whose score is parsed from the matching
 * "Probability_&lt;species&gt;" result variable.
 *
 * @return the prediction provider wrapping the PMML model
 */
private PredictionProvider getModel() {
    return inputs -> CompletableFuture.supplyAsync(() -> {
        List<PredictionOutput> predictions = new ArrayList<>();
        for (PredictionInput predictionInput : inputs) {
            List<Feature> features = predictionInput.getFeatures();
            // the first four features are fed to the executor as numbers, in order
            double first = features.get(0).getValue().asNumber();
            double second = features.get(1).getValue().asNumber();
            double third = features.get(2).getValue().asNumber();
            double fourth = features.get(3).getValue().asNumber();
            LogisticRegressionIrisDataExecutor executor = new LogisticRegressionIrisDataExecutor(first, second, third, fourth);
            PMML4Result pmmlResult = executor.execute(logisticRegressionIrisRuntime);
            String species = pmmlResult.getResultVariables().get("Species").toString();
            String probability = pmmlResult.getResultVariables().get("Probability_" + species).toString();
            double score = Double.parseDouble(probability);
            Output speciesOutput = new Output("species", Type.TEXT, new Value(species), score);
            predictions.add(new PredictionOutput(List.of(speciesOutput)));
        }
        return predictions;
    });
}
Aggregations