Use of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.
From the class ExplanationServiceImplTest, method testCounterfactualsExplainAsyncSuccess.
/**
 * Checks the successful counterfactual path: the mocked explainer completes with
 * {@code COUNTERFACTUAL_RESULT}, and the result produced by {@code invocation} must be a
 * {@code CounterfactualExplainabilityResult} in {@code SUCCEEDED} state that mirrors it.
 *
 * @param invocation supplier that yields the explainability result; it must not throw
 */
@SuppressWarnings("unchecked")
void testCounterfactualsExplainAsyncSuccess(ThrowingSupplier<BaseExplainabilityResult> invocation) {
    // Stub the handler lookup and make the explainer finish immediately with the canned result.
    when(instance.stream()).thenReturn(Stream.of(cfExplainerServiceHandlerMock));
    when(cfExplainerMock.explainAsync(
            any(Prediction.class),
            eq(predictionProviderMock),
            any(Consumer.class)))
                    .thenReturn(CompletableFuture.completedFuture(COUNTERFACTUAL_RESULT));

    BaseExplainabilityResult outcome = assertDoesNotThrow(invocation);

    // The outcome must exist and be of the counterfactual flavour before it can be inspected.
    assertNotNull(outcome);
    assertTrue(outcome instanceof CounterfactualExplainabilityResult);
    CounterfactualExplainabilityResult cfResult = (CounterfactualExplainabilityResult) outcome;

    // Identifiers and status.
    assertEquals(EXECUTION_ID, cfResult.getExecutionId());
    assertEquals(COUNTERFACTUAL_ID, cfResult.getCounterfactualId());
    assertSame(ExplainabilityStatus.SUCCEEDED, cfResult.getStatus());
    assertNull(cfResult.getStatusDetails());

    // Input and output counts must match the canned result.
    assertEquals(COUNTERFACTUAL_RESULT.getEntities().size(), cfResult.getInputs().size());
    assertEquals(COUNTERFACTUAL_RESULT.getOutput().size(), cfResult.getOutputs().size());
    assertTrue(cfResult.getOutputs().stream().anyMatch(o -> o.getName().equals("output1")));

    // The first output returned by the iterator must be a unit value holding the double 555.0.
    NamedTypedValue firstOutput = cfResult.getOutputs().iterator().next();
    assertTrue(firstOutput.getValue().isUnit());
    assertEquals(Double.class.getSimpleName(), firstOutput.getValue().toUnit().getType());
    assertEquals(555.0, firstOutput.getValue().toUnit().getValue().asDouble());
}
Use of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.
From the class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithStructuredOutputModel.
/**
 * A request whose second collection argument contains a {@code StructureValue} entry
 * must make {@code handler.getPrediction} throw {@link IllegalArgumentException}.
 */
@Test
public void testGetPredictionWithStructuredOutputModel() {
    // Build the nested value bottom-up: unit -> structure -> named entry.
    UnitValue nestedNumber = new UnitValue("number", new IntNode(55));
    StructureValue structuredValue = new StructureValue("number", Map.of("input2b", nestedNumber));
    NamedTypedValue structuredEntry = new NamedTypedValue("input1", structuredValue);

    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            List.of(structuredEntry),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    assertThrows(IllegalArgumentException.class, () -> handler.getPrediction(request));
}
Use of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.
From the class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatInputModel.
/**
 * A request carrying a single flat numeric entry ("input1" = 20) in its first collection
 * argument must be converted by {@code handler.getPrediction} into a
 * {@code CounterfactualPrediction} with exactly one NUMBER feature of that name and value,
 * no outputs, and the same maximum running time as the request.
 */
@Test
public void testGetPredictionWithFlatInputModel() {
    NamedTypedValue flatInput = new NamedTypedValue("input1", new UnitValue("number", new IntNode(20)));
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            List.of(flatInput),
            Collections.emptyList(),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;

    // Exactly one feature, and it is the one supplied in the request.
    assertEquals(1, counterfactualPrediction.getInput().getFeatures().size());
    Optional<Feature> oInput1 = counterfactualPrediction.getInput().getFeatures().stream()
            .filter(f -> f.getName().equals("input1"))
            .findFirst();
    assertTrue(oInput1.isPresent());
    Feature input1 = oInput1.get();
    assertEquals(Type.NUMBER, input1.getType());
    assertEquals(20, input1.getValue().asNumber());

    // The third collection argument was empty: every feature must report an empty domain
    // and be flagged as constrained.
    assertTrue(counterfactualPrediction.getInput().getFeatures().stream().allMatch(f -> f.getDomain().isEmpty()));
    assertTrue(counterfactualPrediction.getInput().getFeatures().stream().allMatch(Feature::isConstrained));

    assertTrue(counterfactualPrediction.getOutput().getOutputs().isEmpty());

    // Expected value (from the request) goes first so a failure message reads correctly.
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
Use of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.
From the class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatOutputModel.
/**
 * A request carrying a single flat numeric entry ("output1" = 20) in its second collection
 * argument must be converted by {@code handler.getPrediction} into a
 * {@code CounterfactualPrediction} with exactly one NUMBER output of that name and value,
 * no input features, and the same maximum running time as the request.
 */
@Test
public void testGetPredictionWithFlatOutputModel() {
    NamedTypedValue flatOutput = new NamedTypedValue("output1", new UnitValue("number", new IntNode(20)));
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            List.of(flatOutput),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;

    // Exactly one output, and it is the one supplied in the request.
    assertEquals(1, counterfactualPrediction.getOutput().getOutputs().size());
    Optional<Output> oOutput1 = counterfactualPrediction.getOutput().getOutputs().stream()
            .filter(f -> f.getName().equals("output1"))
            .findFirst();
    assertTrue(oOutput1.isPresent());
    Output output1 = oOutput1.get();
    assertEquals(Type.NUMBER, output1.getType());
    assertEquals(20, output1.getValue().asNumber());

    assertTrue(counterfactualPrediction.getInput().getFeatures().isEmpty());

    // Expected value (from the request) goes first so a failure message reads correctly.
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
Use of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.
From the class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithCollectionInputModel.
/**
 * A request whose first collection argument contains a {@code CollectionValue} entry
 * must make {@code handler.getPrediction} throw {@link IllegalArgumentException}.
 */
@Test
public void testGetPredictionWithCollectionInputModel() {
    // Build the nested value bottom-up: unit -> collection -> named entry.
    UnitValue element = new UnitValue("number", new IntNode(100));
    CollectionValue collectionValue = new CollectionValue("number", List.of(element));
    NamedTypedValue collectionEntry = new NamedTypedValue("input1", collectionValue);

    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            List.of(collectionEntry),
            Collections.emptyList(),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    assertThrows(IllegalArgumentException.class, () -> handler.getPrediction(request));
}
Aggregations