Use of org.kie.kogito.explainability.model.PredictionProvider in project kogito-apps by kiegroup.
In class CountingOptimizationStrategyTest, method testMaybeOptimize:
@Test
void testMaybeOptimize() {
    // Arrange: a counting strategy backed by a mocked optimization service.
    LimeOptimizationService service = mock(LimeOptimizationService.class);
    CountingOptimizationStrategy countingStrategy = new CountingOptimizationStrategy(10, service);

    // No recorded predictions, a mocked model, and default-constructed LIME explainer/config.
    List<Prediction> noPredictions = Collections.emptyList();
    PredictionProvider mockedModel = mock(PredictionProvider.class);
    LimeExplainer limeExplainer = new LimeExplainer();
    LimeConfig limeConfig = new LimeConfig();

    // Act + assert: asking the strategy to optimize over an empty history must not throw.
    assertThatCode(() -> countingStrategy.maybeOptimize(noPredictions, mockedModel, limeExplainer, limeConfig))
            .doesNotThrowAnyException();
}
Use of org.kie.kogito.explainability.model.PredictionProvider in project kogito-apps by kiegroup.
In class TestUtils, method getFixedOutputClassifier:
/**
 * Returns a stub classifier that ignores its inputs entirely.
 *
 * <p>For each input it emits one {@code PredictionOutput} holding a single BOOLEAN output named
 * {@code "class"} whose value is always {@code false}. The constant {@code 1d} is passed as the
 * fourth {@code Output} argument (presumably the score; confirm against {@code Output}).
 * Predictions are computed via {@code supplyAsync}.
 *
 * @return a provider producing exactly one output per input, in input order
 */
public static PredictionProvider getFixedOutputClassifier() {
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> fixedOutputs = new LinkedList<>();
        // The input's content is irrelevant; only the number of inputs matters.
        inputs.forEach(unused -> fixedOutputs.add(
                new PredictionOutput(List.of(new Output("class", Type.BOOLEAN, new Value(false), 1d)))));
        return fixedOutputs;
    });
}
Use of org.kie.kogito.explainability.model.PredictionProvider in project kogito-apps by kiegroup.
In class TestUtils, method getSumSkipModel:
/**
 * Returns a model that sums every feature of an input except the one at {@code skipFeatureIndex}.
 *
 * <p>Each feature value is read with {@code getValue().asNumber()}. Each input yields one
 * NUMBER output named {@code "sum-but" + skipFeatureIndex}. If {@code skipFeatureIndex} is out
 * of range, no feature is skipped.
 *
 * @param skipFeatureIndex zero-based position of the feature to leave out of the sum
 * @return a provider producing one output per input, in input order
 */
public static PredictionProvider getSumSkipModel(int skipFeatureIndex) {
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> results = new LinkedList<>();
        for (PredictionInput input : inputs) {
            double sum = 0;
            int position = 0;
            // Walk features in order so the floating-point summation order is unchanged.
            for (Feature feature : input.getFeatures()) {
                if (position++ == skipFeatureIndex) {
                    continue;
                }
                sum += feature.getValue().asNumber();
            }
            Output sumOutput = new Output("sum-but" + skipFeatureIndex, Type.NUMBER, new Value(sum), 1d);
            results.add(new PredictionOutput(List.of(sumOutput)));
        }
        return results;
    });
}
Use of org.kie.kogito.explainability.model.PredictionProvider in project kogito-apps by kiegroup.
In class TestUtils, method getSumSkipTwoOutputModel:
/**
 * Returns a two-output variant of the skip-one summing model.
 *
 * <p>For each input it computes the sum of all feature values except the one at
 * {@code skipFeatureIndex}. It then emits two NUMBER outputs: the sum itself, named
 * {@code "sum-but" + skipFeatureIndex}, and twice the sum, named with an extra {@code "*2"} suffix.
 *
 * @param skipFeatureIndex zero-based position of the feature to leave out of the sum
 * @return a provider producing one two-output prediction per input, in input order
 */
public static PredictionProvider getSumSkipTwoOutputModel(int skipFeatureIndex) {
    // Output names depend only on skipFeatureIndex, so build them once.
    final String sumName = "sum-but" + skipFeatureIndex;
    final String doubledName = sumName + "*2";
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> results = new LinkedList<>();
        for (PredictionInput input : inputs) {
            List<Feature> featureList = input.getFeatures();
            double partialSum = 0;
            int position = 0;
            while (position < featureList.size()) {
                if (position != skipFeatureIndex) {
                    partialSum += featureList.get(position).getValue().asNumber();
                }
                position++;
            }
            Output sumOutput = new Output(sumName, Type.NUMBER, new Value(partialSum), 1d);
            Output doubledOutput = new Output(doubledName, Type.NUMBER, new Value(partialSum * 2), 1d);
            results.add(new PredictionOutput(List.of(sumOutput, doubledOutput)));
        }
        return results;
    });
}
Use of org.kie.kogito.explainability.model.PredictionProvider in project kogito-apps by kiegroup.
In class TestUtils, method getLinearModel:
/**
 * Returns a linear model whose single NUMBER output {@code "linear-sum"} is the dot product
 * of an input's feature values (read via {@code getValue().asNumber()}) with {@code weights}.
 *
 * <p>The weights are copied when this method is called. Later changes to the caller's array
 * therefore do not affect the returned model.
 *
 * <p>If an input has more features than there are weights, the returned future completes
 * exceptionally with an {@link IllegalArgumentException} describing the mismatch. Extra
 * weights beyond the feature count are ignored.
 *
 * @param weights one coefficient per feature position; must not be {@code null}
 * @return a provider producing one output per input, in input order
 * @throws NullPointerException if {@code weights} is {@code null}
 */
public static PredictionProvider getLinearModel(double[] weights) {
    // Defensive copy: the lambda outlives this call, so capturing the caller's array directly
    // would let later mutations silently change predictions. clone() also fails fast on null.
    final double[] coefficients = weights.clone();
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> predictionOutputs = new LinkedList<>();
        for (PredictionInput predictionInput : inputs) {
            List<Feature> features = predictionInput.getFeatures();
            if (features.size() > coefficients.length) {
                // Replaces an opaque ArrayIndexOutOfBoundsException with an actionable message.
                throw new IllegalArgumentException("linear model has " + coefficients.length
                        + " weights but input has " + features.size() + " features");
            }
            double result = 0;
            for (int i = 0; i < features.size(); i++) {
                result += features.get(i).getValue().asNumber() * coefficients[i];
            }
            PredictionOutput predictionOutput = new PredictionOutput(List.of(new Output("linear-sum", Type.NUMBER, new Value(result), 1d)));
            predictionOutputs.add(predictionOutput);
        }
        return predictionOutputs;
    });
}
Aggregations