Search in sources:

Example 1 with WeightedLinearRegressionResults

Use of org.kie.kogito.explainability.utils.WeightedLinearRegressionResults in the kogito-apps project by kiegroup.

In class ShapKernelExplainer, method runWLRR.

/**
 * Runs the weighted linear regression (WLR) for a single model output and maps its results back onto the
 * full feature space.
 *
 * <p>The coefficients (and their confidence bounds) fitted for the columns of {@code maskDiff} are written,
 * in iteration order, to the feature positions listed in {@code nonzeros}, skipping {@code dropIdx}. The
 * SHAP value of the regularization feature {@code dropIdx} is then set to {@code outputChange} minus the sum
 * of all fitted coefficients. Its confidence bound is the root-sum-square of all fitted bounds. Features
 * that do not appear in {@code nonzeros} keep a value and a bound of {@code 0}.
 *
 * @param maskDiff the mask matrix, not including the regularization feature
 * @param adjY the expected model outputs, adjusted for dropping the regularization feature
 * @param ws the weight of each sample
 * @param outputChange the raw difference between the model output and the null output
 * @param dropIdx the index of the regularization feature
 * @param nonzeros the feature indices to which the fitted coefficients are assigned, one coefficient per
 *     index other than {@code dropIdx}, consumed in iteration order. NOTE(review): this presumably must
 *     match the column order of {@code maskDiff}; confirm against the caller.
 * @param sdc data carrier; {@code sdc.getCols()} gives the length of both returned vectors
 * @return a two-element array: index 0 holds the SHAP values found by the WLR, and index 1 holds the
 *     confidence bounds of those values. Both vectors have length {@code sdc.getCols()}.
 */
private RealVector[] runWLRR(RealMatrix maskDiff, RealVector adjY, RealVector ws, double outputChange, int dropIdx, List<Integer> nonzeros, ShapDataCarrier sdc) {
    // NOTE(review): the trailing `false` is presumably the "fit intercept" flag; confirm in WeightedLinearRegression
    WeightedLinearRegressionResults wlrr = WeightedLinearRegression.fit(maskDiff, adjY, ws, false);
    RealVector coeffs = wlrr.getCoefficients();
    RealVector bounds = wlrr.getConf(1 - this.config.getConfidence());

    int nFeatures = sdc.getCols();
    RealVector shapSlice = MatrixUtils.createRealVector(new double[nFeatures]);
    RealVector boundsReg = MatrixUtils.createRealVector(new double[nFeatures]);

    // coefficients are consumed in order, one per entry of nonzeros other than dropIdx
    int usedCoefs = 0;
    for (int idx : nonzeros) {
        if (idx != dropIdx) {
            shapSlice.setEntry(idx, coeffs.getEntry(usedCoefs));
            boundsReg.setEntry(idx, bounds.getEntry(usedCoefs));
            usedCoefs++;
        }
    }

    // the regularization feature receives whatever part of outputChange the fitted coefficients do not cover
    shapSlice.setEntry(dropIdx, outputChange - MatrixUtilsExtensions.sum(coeffs));
    // propagate the error of that sum as the root-sum-square (L2 norm) of the individual bounds,
    // i.e. treating the coefficient errors as independent
    boundsReg.setEntry(dropIdx, bounds.getNorm());

    // bundle the SHAP values and their confidence bounds together
    return new RealVector[] {shapSlice, boundsReg};
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) LarsPath(org.kie.kogito.explainability.utils.LarsPath) Prediction(org.kie.kogito.explainability.model.Prediction) LoggerFactory(org.slf4j.LoggerFactory) HashMap(java.util.HashMap) CompletableFuture(java.util.concurrent.CompletableFuture) RealVector(org.apache.commons.math3.linear.RealVector) WeightedLinearRegression(org.kie.kogito.explainability.utils.WeightedLinearRegression) Saliency(org.kie.kogito.explainability.model.Saliency) ArrayList(java.util.ArrayList) MathArithmeticException(org.apache.commons.math3.exception.MathArithmeticException) MatrixUtils(org.apache.commons.math3.linear.MatrixUtils) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) LassoLarsIC(org.kie.kogito.explainability.utils.LassoLarsIC) CombinatoricsUtils(org.apache.commons.math3.util.CombinatoricsUtils) Logger(org.slf4j.Logger) Iterator(java.util.Iterator) LocalExplainer(org.kie.kogito.explainability.local.LocalExplainer) AnyMatrix(org.apache.commons.math3.linear.AnyMatrix) WeightedLinearRegressionResults(org.kie.kogito.explainability.utils.WeightedLinearRegressionResults) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) Collectors(java.util.stream.Collectors) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Consumer(java.util.function.Consumer) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) MatrixUtilsExtensions(org.kie.kogito.explainability.utils.MatrixUtilsExtensions) RandomChoice(org.kie.kogito.explainability.utils.RandomChoice) RealMatrix(org.apache.commons.math3.linear.RealMatrix) Collections(java.util.Collections) RealVector(org.apache.commons.math3.linear.RealVector) WeightedLinearRegressionResults(org.kie.kogito.explainability.utils.WeightedLinearRegressionResults)

Aggregations

ArrayList (java.util.ArrayList)1 Arrays (java.util.Arrays)1 Collections (java.util.Collections)1 HashMap (java.util.HashMap)1 Iterator (java.util.Iterator)1 List (java.util.List)1 CompletableFuture (java.util.concurrent.CompletableFuture)1 Consumer (java.util.function.Consumer)1 Collectors (java.util.stream.Collectors)1 IntStream (java.util.stream.IntStream)1 MathArithmeticException (org.apache.commons.math3.exception.MathArithmeticException)1 AnyMatrix (org.apache.commons.math3.linear.AnyMatrix)1 MatrixUtils (org.apache.commons.math3.linear.MatrixUtils)1 RealMatrix (org.apache.commons.math3.linear.RealMatrix)1 RealVector (org.apache.commons.math3.linear.RealVector)1 CombinatoricsUtils (org.apache.commons.math3.util.CombinatoricsUtils)1 LocalExplainer (org.kie.kogito.explainability.local.LocalExplainer)1 FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance)1 Prediction (org.kie.kogito.explainability.model.Prediction)1 PredictionInput (org.kie.kogito.explainability.model.PredictionInput)1