Use of org.kie.kogito.explainability.utils.WeightedLinearRegressionResults in project kogito-apps by kiegroup.
The following example is taken from the class ShapKernelExplainer, method runWLRR.
/**
 * Run the weighted linear regression (WLR) model over the expectations.
 *
 * @param maskDiff: The mask matrix, not including the regularization feature
 * @param adjY: The expected model outputs, adjusted for dropping the regularization feature
 * @param ws: The weights of each sample
 * @param outputChange: The raw difference between the model output and the null output
 * @param dropIdx: The regularization feature index
 * @param nonzeros: The indices of the features included in the regression
 * @param sdc: The ShapDataCarrier containing the data for this explanation
 *
 * @return a length-2 array of RealVectors, containing the SHAP values found by the WLR in the first entry and the
 *         confidences of those values in the second entry.
 */
// run the WLRR for a single output
private RealVector[] runWLRR(RealMatrix maskDiff, RealVector adjY, RealVector ws, double outputChange, int dropIdx, List<Integer> nonzeros, ShapDataCarrier sdc) {
    // temporary conversion to and from MatrixUtils data structures; these will be used throughout after FAI-661
    WeightedLinearRegressionResults wlrr = WeightedLinearRegression.fit(maskDiff, adjY, ws, false);
    RealVector coeffs = wlrr.getCoefficients();
    RealVector bounds = wlrr.getConf(1 - this.config.getConfidence());
    int usedCoefs = 0;
    RealVector shapSlice = MatrixUtils.createRealVector(new double[sdc.getCols()]);
    RealVector boundsReg = shapSlice.copy();
    for (int idx : nonzeros) {
        if (idx != dropIdx) {
            shapSlice.setEntry(idx, coeffs.getEntry(usedCoefs));
            boundsReg.setEntry(idx, bounds.getEntry(usedCoefs));
            usedCoefs += 1;
        }
    }
    shapSlice.setEntry(dropIdx, outputChange - MatrixUtilsExtensions.sum(coeffs));
    // propagate the error of the sum
    boundsReg.setEntry(dropIdx, Math.sqrt(MatrixUtilsExtensions.sum(bounds.map(x -> x * x))));
    // bundle error and shap values together
    RealVector[] wlrrOutput = new RealVector[2];
    wlrrOutput[0] = shapSlice;
    wlrrOutput[1] = boundsReg;
    return wlrrOutput;
}
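
For reference, the snippet below is a minimal standalone sketch of the same WeightedLinearRegressionResults usage outside of ShapKernelExplainer. It assumes the Apache Commons Math RealMatrix/RealVector types and the WeightedLinearRegression.fit, getCoefficients, and getConf calls shown in runWLRR above; the input matrix, outputs, and weights are made-up illustrative values, and WlrrUsageSketch is a hypothetical class name.

import java.util.Arrays;

import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.kie.kogito.explainability.utils.WeightedLinearRegression;
import org.kie.kogito.explainability.utils.WeightedLinearRegressionResults;

public class WlrrUsageSketch {
    public static void main(String[] args) {
        // Illustrative mask matrix: 4 samples x 2 features (regularization feature already dropped)
        RealMatrix maskDiff = MatrixUtils.createRealMatrix(new double[][] {
                { 1., 0. },
                { 0., 1. },
                { 1., 1. },
                { 0., 0. }
        });
        // Illustrative adjusted model outputs and per-sample weights
        RealVector adjY = MatrixUtils.createRealVector(new double[] { .5, .3, .8, .0 });
        RealVector ws = MatrixUtils.createRealVector(new double[] { 1., 1., 2., 1. });

        // Fit without an intercept, as runWLRR does
        WeightedLinearRegressionResults wlrr = WeightedLinearRegression.fit(maskDiff, adjY, ws, false);

        // The coefficients are used as the SHAP values; getConf(alpha) gives the
        // corresponding confidence bounds (alpha = 1 - confidence, e.g. 0.05 for 95%)
        RealVector coeffs = wlrr.getCoefficients();
        RealVector bounds = wlrr.getConf(0.05);

        System.out.println("coefficients: " + Arrays.toString(coeffs.toArray()));
        System.out.println("bounds:       " + Arrays.toString(bounds.toArray()));
    }
}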