Usage of org.kie.kogito.explainability.model.PredictionInput in the kiegroup project kogito-apps.
Class LimeExplainer, method prepareInputs.
/**
 * Check the perturbed inputs so that the dataset of perturbed inputs / outputs contains more than just one output
 * class; otherwise it would be impossible to linearly separate it, and hence to learn meaningful weights to be used
 * as feature importance scores.
 * <p>
 * The dataset is considered separable when it contains more than one output class and the most frequent class
 * covers less than {@code limeConfig.getSeparableDatasetRatio()} of the perturbed samples.
 * When {@code strict} is {@code true}, a {@code DatasetNotSeparableException} is thrown for a non separable dataset;
 * when it is {@code false}, the dataset is used anyway and a warning is logged only if the check fails.
 *
 * @param perturbedInputs the perturbed inputs
 * @param perturbedOutputs the outputs obtained for {@code perturbedInputs}
 * @param linearizedTargetInputFeatures the linearized features of the input being explained
 * @param o the index of the output under consideration inside each {@code PredictionOutput}
 * @param currentOutput the original output being explained
 * @param strict whether a non separable dataset must cause an exception
 * @return the inputs for the linear model; with empty samples when {@code currentOutput} carries no value
 * @throws DatasetNotSeparableException if {@code strict} is {@code true} and the dataset is not separable
 */
private LimeInputs prepareInputs(List<PredictionInput> perturbedInputs, List<PredictionOutput> perturbedOutputs,
        List<Feature> linearizedTargetInputFeatures, int o, Output currentOutput, boolean strict) {
    Value fv = currentOutput.getValue();
    if (fv == null || fv.getUnderlyingObject() == null) {
        // no value to explain: return inputs with no samples
        return new LimeInputs(false, linearizedTargetInputFeatures, currentOutput, emptyList(), emptyList());
    }

    // calculate the no. of samples belonging to each output class
    Map<Double, Long> rawClassesBalance = getClassBalance(perturbedOutputs, fv, o);
    long max = rawClassesBalance.values().stream().max(Long::compareTo).orElse(1L);
    double separationRatio = (double) max / (double) perturbedInputs.size();

    List<Output> outputs = perturbedOutputs.stream().map(po -> po.getOutputs().get(o)).collect(Collectors.toList());
    // exactly two classes -> the linear model should fit a classifier rather than a regressor
    boolean classification = rawClassesBalance.size() == 2;

    boolean separable = rawClassesBalance.size() > 1 && separationRatio < limeConfig.getSeparableDatasetRatio();
    if (!separable) {
        if (strict) {
            throw new DatasetNotSeparableException(currentOutput, rawClassesBalance);
        }
        // non-strict mode: proceed with the hardly separable dataset, but only warn when it really is one
        LOGGER.warn("Using an hardly separable dataset for output '{}' of type '{}' with value '{}' ({})",
                currentOutput.getName(), currentOutput.getType(), currentOutput.getValue(), rawClassesBalance);
    }
    // if dataset creation process succeeds use it to train the linear model
    return new LimeInputs(classification, linearizedTargetInputFeatures, currentOutput, perturbedInputs, outputs);
}
Usage of org.kie.kogito.explainability.model.PredictionInput in the kiegroup project kogito-apps.
Class LimeExplainer, method explainAsync.
/**
 * Explains {@code prediction} with LIME, running on a copy of the configured {@code limeConfig} so that any
 * adjustment made during this execution does not leak into the shared configuration.
 *
 * @param prediction the prediction to explain; its input must contain at least one feature
 * @param model the model used to evaluate perturbed inputs
 * @param intermediateResultsConsumer not used by this method
 * @return the future produced by {@code explainWithExecutionConfig}
 * @throws LocalExplanationException if the prediction input is empty, its output is missing, or input
 *         linearization yields no features
 */
@Override
public CompletableFuture<Map<String, Saliency>> explainAsync(Prediction prediction, PredictionProvider model,
        Consumer<Map<String, Saliency>> intermediateResultsConsumer) {
    PredictionInput originalInput = prediction.getInput();
    if (originalInput == null || originalInput.getFeatures() == null || originalInput.getFeatures().isEmpty()) {
        throw new LocalExplanationException("cannot explain a prediction whose input is empty");
    }
    // fail fast with a descriptive error instead of an NPE after the (costlier) linearization step
    if (prediction.getOutput() == null || prediction.getOutput().getOutputs() == null) {
        throw new LocalExplanationException("cannot explain a prediction whose output is missing");
    }

    List<PredictionInput> linearizedInputs = DataUtils.linearizeInputs(List.of(originalInput));
    List<Feature> linearizedTargetInputFeatures = linearizedInputs.get(0).getFeatures();
    if (linearizedTargetInputFeatures.isEmpty()) {
        throw new LocalExplanationException("input features linearization failed");
    }

    List<Output> actualOutputs = prediction.getOutput().getOutputs();
    LimeConfig executionConfig = limeConfig.copy();
    return explainWithExecutionConfig(model, originalInput, linearizedTargetInputFeatures, actualOutputs,
            executionConfig);
}
Usage of org.kie.kogito.explainability.model.PredictionInput in the kiegroup project kogito-apps.
Class LocalDMNPredictionProvider, method predictAsync.
/**
 * Evaluates the DMN model once per input, in input order, and returns the outputs in an already completed future
 * (the evaluation itself runs synchronously on the calling thread).
 *
 * @param inputs the inputs to evaluate; each is converted with {@code toMap} and its {@code DUMMY_DMN_CONTEXT_KEY}
 *        entry is used as the DMN context
 * @return a completed future holding one {@code PredictionOutput} per input
 */
@Override
public CompletableFuture<List<PredictionOutput>> predictAsync(List<PredictionInput> inputs) {
    List<PredictionOutput> predictionOutputs = new ArrayList<>(inputs.size());
    for (PredictionInput input : inputs) {
        // the context entry is untyped here; NOTE(review): assumes it always holds a Map<String, Object> — confirm
        @SuppressWarnings("unchecked")
        Map<String, Object> contextVariables =
                (Map<String, Object>) toMap(input.getFeatures()).get(DUMMY_DMN_CONTEXT_KEY);
        predictionOutputs.add(toPredictionOutput(dmnEvaluator.evaluate(contextVariables)));
    }
    return completedFuture(predictionOutputs);
}
Usage of org.kie.kogito.explainability.model.PredictionInput in the kiegroup project kogito-apps.
Class CounterfactualExplainerServiceHandler, method getPrediction.
/**
 * Builds a {@code CounterfactualPrediction} from a counterfactual explainability request.
 * <p>
 * A requested maximum running time larger than {@code kafkaMaxRecordAgeSeconds} is capped to that value, and the
 * capping is logged.
 *
 * @param request the counterfactual explainability request
 * @return the prediction built from the request's original inputs, search domains and goals
 * @throws IllegalArgumentException if {@code isUnsupportedModel} reports the model as unsupported (non-flat)
 */
@Override
public Prediction getPrediction(CounterfactualExplainabilityRequest request) {
    Collection<NamedTypedValue> goals = toMapBasedSorting(request.getGoals());
    Collection<CounterfactualSearchDomain> searchDomains = request.getSearchDomains();
    Collection<NamedTypedValue> originalInputs = request.getOriginalInputs();

    // See https://issues.redhat.com/browse/FAI-473 and https://issues.redhat.com/browse/FAI-474
    // validate first so an unsupported request is rejected before any other processing
    if (isUnsupportedModel(originalInputs, goals, searchDomains)) {
        throw new IllegalArgumentException("Counterfactual explanations only support flat models.");
    }

    Long maxRunningTimeSeconds = request.getMaxRunningTimeSeconds();
    // presumably so the search does not outlive the record on the messaging sub-system (per the log text)
    if (maxRunningTimeSeconds != null && maxRunningTimeSeconds > kafkaMaxRecordAgeSeconds) {
        LOGGER.info("Maximum Running Timeout set to '{}'s since the provided value '{}'s exceeded the Messaging sub-system configuration '{}'s.",
                kafkaMaxRecordAgeSeconds, maxRunningTimeSeconds, kafkaMaxRecordAgeSeconds);
        maxRunningTimeSeconds = kafkaMaxRecordAgeSeconds;
    }

    PredictionInput input = new PredictionInput(toFeatureList(originalInputs, searchDomains));
    PredictionOutput output = new PredictionOutput(toOutputList(goals));
    return new CounterfactualPrediction(input, output, null, UUID.fromString(request.getExecutionId()),
            maxRunningTimeSeconds);
}
Usage of org.kie.kogito.explainability.model.PredictionInput in the kiegroup project kogito-apps.
Class LimeExplainerServiceHandler, method getPrediction.
/**
 * Converts a LIME explainability request into a {@code SimplePrediction}: the request inputs become the
 * prediction's features and the request outputs become its outputs.
 *
 * @param request the LIME explainability request
 * @return a {@code SimplePrediction} wrapping the converted inputs and outputs
 */
@Override
public Prediction getPrediction(LIMEExplainabilityRequest request) {
    Collection<NamedTypedValue> requestInputs = request.getInputs();
    Collection<NamedTypedValue> requestOutputs = request.getOutputs();
    return new SimplePrediction(
            new PredictionInput(toFeatureList(requestInputs)),
            new PredictionOutput(toOutputList(requestOutputs)));
}
Aggregations