Example 1 with LinearModel

Use of org.kie.kogito.explainability.utils.LinearModel in project kogito-apps by kiegroup.

From class LimeExplainer, method getSaliency.

private void getSaliency(List<Feature> linearizedTargetInputFeatures, Map<String, Saliency> result, LimeInputs limeInputs, Output originalOutput, LimeConfig executionConfig) {
    int numFeatures = linearizedTargetInputFeatures.size();
    List<FeatureImportance> featureImportanceList = new ArrayList<>();
    // encode the perturbed inputs and outputs so that they can be fed into the linear model
    DatasetEncoder datasetEncoder = new DatasetEncoder(limeInputs.getPerturbedInputs(), limeInputs.getPerturbedOutputs(), linearizedTargetInputFeatures, originalOutput, executionConfig.getEncodingParams());
    List<Pair<double[], Double>> trainingSet = datasetEncoder.getEncodedTrainingSet();
    // weight the training samples by their proximity to the target input being explained;
    // the kernel width is scaled by the square root of the number of features
    double kernelWidth = executionConfig.getProximityKernelWidth() * Math.sqrt(numFeatures);
    double[] sampleWeights = SampleWeighter.getSampleWeights(linearizedTargetInputFeatures, trainingSet, kernelWidth);
    // per-feature multipliers applied to the final importances (1.0 = unchanged)
    double[] featureWeights = new double[numFeatures];
    Arrays.fill(featureWeights, 1.0);
    if (executionConfig.isPenalizeBalanceSparse()) {
        // optionally adjust featureWeights with the sparse-feature balance filter
        IndependentSparseFeatureBalanceFilter sparseFeatureBalanceFilter = new IndependentSparseFeatureBalanceFilter();
        sparseFeatureBalanceFilter.apply(featureWeights, linearizedTargetInputFeatures, trainingSet);
    }
    if (executionConfig.isProximityFilter()) {
        // optionally apply the proximity filter (threshold and minimum dataset size come from the config)
        ProximityFilter proximityFilter = new ProximityFilter(executionConfig.getProximityThreshold(), executionConfig.getProximityFilteredDatasetMinimum().doubleValue());
        proximityFilter.apply(trainingSet, sampleWeights);
    }
    // fit the sample-weighted linear surrogate model
    LinearModel linearModel = new LinearModel(numFeatures, limeInputs.isClassification());
    double loss = linearModel.fit(trainingSet, sampleWeights);
    if (!Double.isNaN(loss)) {
        // create the output saliency from the surrogate's weights
        double[] weights = linearModel.getWeights();
        // read the normalization flag from executionConfig, the config passed to this method
        if (executionConfig.isNormalizeWeights() && weights.length > 0) {
            normalizeWeights(weights);
        }
        for (int i = 0; i < numFeatures; i++) {
            Feature linearizedFeature = linearizedTargetInputFeatures.get(i);
            // importance = surrogate weight scaled by the per-feature multiplier
            featureImportanceList.add(new FeatureImportance(linearizedFeature, weights[i] * featureWeights[i]));
        }
    }
    // if the fit failed (NaN loss), the saliency is created with an empty importance list
    Saliency saliency = new Saliency(originalOutput, featureImportanceList);
    result.put(originalOutput.getName(), saliency);
}
Also used: ArrayList (java.util.ArrayList), Saliency (org.kie.kogito.explainability.model.Saliency), Feature (org.kie.kogito.explainability.model.Feature), FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance), LinearModel (org.kie.kogito.explainability.utils.LinearModel), Pair (org.apache.commons.lang3.tuple.Pair)
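The proximity-weighting step in getSaliency can be illustrated with a small standalone sketch. It assumes the exponential kernel used by the reference Python LIME implementation, sqrt(exp(-d^2 / w^2)), a Euclidean distance, and an encoding in which the target input becomes a vector of ones (1 = feature unchanged, 0 = feature perturbed). kogito's SampleWeighter may compute distances and weights differently, and ProximityKernelSketch, kernelWidth, distance and sampleWeights are hypothetical names used only for illustration.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of LIME-style proximity weighting (NOT kogito's SampleWeighter).
 * Assumption: weight = sqrt(exp(-d^2 / kernelWidth^2)), as in the reference Python LIME.
 */
public class ProximityKernelSketch {

    /** Base width scaled by the square root of the feature count, as in getSaliency. */
    static double kernelWidth(double baseWidth, int numFeatures) {
        return baseWidth * Math.sqrt(numFeatures);
    }

    /** Euclidean distance between two vectors of equal length. */
    static double distance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    /** Weights each encoded sample by its proximity to the (assumed all-ones) encoded target. */
    static double[] sampleWeights(double[][] encodedSamples, double kernelWidth) {
        double[] weights = new double[encodedSamples.length];
        for (int s = 0; s < encodedSamples.length; s++) {
            double[] target = new double[encodedSamples[s].length];
            Arrays.fill(target, 1.0);
            double d = distance(encodedSamples[s], target);
            weights[s] = Math.sqrt(Math.exp(-(d * d) / (kernelWidth * kernelWidth)));
        }
        return weights;
    }

    public static void main(String[] args) {
        // Three encoded perturbations of a 4-feature input.
        double[][] samples = {
            {1, 1, 1, 1}, // identical to the target
            {1, 0, 1, 1}, // one feature perturbed
            {0, 0, 0, 0}  // every feature perturbed
        };
        double width = kernelWidth(0.75, 4); // 0.75 * sqrt(4) = 1.5
        System.out.println(Arrays.toString(sampleWeights(samples, width)));
    }
}
```

Samples that differ from the target in more features lie farther away and receive smaller weights, so the linear surrogate is fitted mostly to perturbations near the input being explained.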

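The final steps of getSaliency (fitting a sample-weighted linear surrogate and multiplying each learned weight by its feature weight) can be sketched as follows. This is not kogito's LinearModel: it swaps in plain batch gradient descent on a sample-weighted mean squared error, ignores the classification flag, and uses hypothetical names (WeightedSurrogateSketch, fit, predict) with made-up toy data.

```java
import java.util.Arrays;

/**
 * Hypothetical weighted linear surrogate (NOT kogito's LinearModel), trained with
 * batch gradient descent on a sample-weighted mean squared error.
 */
public class WeightedSurrogateSketch {

    final double[] weights;
    double bias;

    WeightedSurrogateSketch(int numFeatures) {
        this.weights = new double[numFeatures];
    }

    double predict(double[] x) {
        double y = bias;
        for (int j = 0; j < weights.length; j++) {
            y += weights[j] * x[j];
        }
        return y;
    }

    /** Fits the model and returns the final sample-weighted mean squared error. */
    double fit(double[][] xs, double[] ys, double[] sampleWeights, int epochs, double learningRate) {
        double totalWeight = Arrays.stream(sampleWeights).sum();
        for (int epoch = 0; epoch < epochs; epoch++) {
            double[] grad = new double[weights.length];
            double gradBias = 0.0;
            for (int s = 0; s < xs.length; s++) {
                // gradient of sampleWeight * err^2 / totalWeight
                double err = predict(xs[s]) - ys[s];
                double scaled = 2.0 * sampleWeights[s] * err / totalWeight;
                for (int j = 0; j < weights.length; j++) {
                    grad[j] += scaled * xs[s][j];
                }
                gradBias += scaled;
            }
            for (int j = 0; j < weights.length; j++) {
                weights[j] -= learningRate * grad[j];
            }
            bias -= learningRate * gradBias;
        }
        double loss = 0.0;
        for (int s = 0; s < xs.length; s++) {
            double err = predict(xs[s]) - ys[s];
            loss += sampleWeights[s] * err * err;
        }
        return loss / totalWeight;
    }

    public static void main(String[] args) {
        // Encoded perturbations: 1 = feature kept, 0 = feature perturbed.
        double[][] xs = {{1, 1}, {1, 0}, {0, 1}, {0, 0}};
        // Toy black-box outputs following y = 3*x0 + 1*x1.
        double[] ys = {4, 3, 1, 0};
        double[] sampleWeights = {1.0, 0.8, 0.8, 0.4};
        double[] featureWeights = {1.0, 1.0}; // as if no sparse-balance adjustment was applied

        WeightedSurrogateSketch model = new WeightedSurrogateSketch(2);
        double loss = model.fit(xs, ys, sampleWeights, 5000, 0.1);
        if (!Double.isNaN(loss)) {
            for (int j = 0; j < model.weights.length; j++) {
                // importance = surrogate weight * feature weight, mirroring getSaliency
                System.out.printf("feature %d importance: %.3f%n", j, model.weights[j] * featureWeights[j]);
            }
        }
    }
}
```

Because the toy outputs are exactly linear in the encoded features, the weights converge to about 3 and 1 and the printed importances read 3.000 and 1.000; on real perturbation data the surrogate is only a local approximation around the target input.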