Usage of `org.kie.kogito.explainability.utils.LinearModel` in the kogito-apps project by kiegroup.
Class `LimeExplainer`, method `getSaliency`:
/**
 * Fits a weighted {@link LinearModel} on the encoded perturbed samples and stores the resulting
 * per-feature importances, wrapped in a {@link Saliency}, into {@code result} under the name of
 * {@code originalOutput}.
 *
 * @param linearizedTargetInputFeatures the linearized features of the input being explained; one
 *                                      importance is produced per feature, in this order
 * @param result                        map that receives the computed saliency, keyed by
 *                                      {@code originalOutput.getName()}
 * @param limeInputs                    perturbed inputs/outputs and the classification flag used to
 *                                      build and fit the surrogate model
 * @param originalOutput                the output whose saliency is being computed
 * @param executionConfig               the LIME configuration in effect for this execution; all
 *                                      tunables in this method are read from it
 */
private void getSaliency(List<Feature> linearizedTargetInputFeatures, Map<String, Saliency> result,
        LimeInputs limeInputs, Output originalOutput, LimeConfig executionConfig) {
    int featureCount = linearizedTargetInputFeatures.size();

    // encode the training data so that it can be fed into the linear model
    DatasetEncoder datasetEncoder = new DatasetEncoder(limeInputs.getPerturbedInputs(),
            limeInputs.getPerturbedOutputs(), linearizedTargetInputFeatures, originalOutput,
            executionConfig.getEncodingParams());
    List<Pair<double[], Double>> trainingSet = datasetEncoder.getEncodedTrainingSet();

    // weight the training samples based on the proximity to the target input to explain;
    // the configured kernel width is scaled by sqrt(number of features)
    double kernelWidth = executionConfig.getProximityKernelWidth() * Math.sqrt(featureCount);
    double[] sampleWeights = SampleWeighter.getSampleWeights(linearizedTargetInputFeatures, trainingSet, kernelWidth);

    // per-feature multipliers applied to the fitted weights; 1 leaves a weight unchanged
    double[] featureWeights = new double[featureCount];
    Arrays.fill(featureWeights, 1);
    if (executionConfig.isPenalizeBalanceSparse()) {
        // may lower entries of featureWeights in place
        IndependentSparseFeatureBalanceFilter sparseFeatureBalanceFilter = new IndependentSparseFeatureBalanceFilter();
        sparseFeatureBalanceFilter.apply(featureWeights, linearizedTargetInputFeatures, trainingSet);
    }
    if (executionConfig.isProximityFilter()) {
        // operates on trainingSet / sampleWeights in place before the fit
        ProximityFilter proximityFilter = new ProximityFilter(executionConfig.getProximityThreshold(),
                executionConfig.getProximityFilteredDatasetMinimum().doubleValue());
        proximityFilter.apply(trainingSet, sampleWeights);
    }

    List<FeatureImportance> featureImportanceList = new ArrayList<>(featureCount);
    LinearModel linearModel = new LinearModel(featureCount, limeInputs.isClassification());
    double loss = linearModel.fit(trainingSet, sampleWeights);
    // when the fit yields a NaN loss no importances are computed, and the saliency is
    // published with an empty importance list
    if (!Double.isNaN(loss)) {
        // create the output saliency
        double[] weights = linearModel.getWeights();
        // read from executionConfig like every other setting here (previously read the limeConfig field)
        if (executionConfig.isNormalizeWeights() && weights.length > 0) {
            normalizeWeights(weights);
        }
        // iterate rather than index into the list, so non-random-access lists stay O(n)
        int i = 0;
        for (Feature linearizedFeature : linearizedTargetInputFeatures) {
            featureImportanceList.add(new FeatureImportance(linearizedFeature, weights[i] * featureWeights[i]));
            i++;
        }
    }
    result.put(originalOutput.getName(), new Saliency(originalOutput, featureImportanceList));
}
Aggregations