Search in sources :

Example 1 with CounterfactualSearchDomain

use of org.kie.kogito.explainability.api.CounterfactualSearchDomain in project kogito-apps by kiegroup.

In class CounterfactualExplainabilityRequestMarshallerTest, method testWriteAndRead.

/**
 * Round-trips a {@link CounterfactualExplainabilityRequest} through the marshaller and checks that the
 * deserialized copy carries the same identifiers, goals, search domains and maximum running time.
 * <p>
 * NOTE(review): {@code writer} and {@code reader} are fields not visible here; presumably they share a
 * buffer so that {@code readFrom(reader)} sees what {@code writeTo(writer, ...)} produced.
 */
@Test
public void testWriteAndRead() throws IOException {
    ModelIdentifier modelIdentifier = new ModelIdentifier("resourceType", "resourceId");
    List<NamedTypedValue> originalInputs = Collections.singletonList(new NamedTypedValue("unitIn", new UnitValue("number", "number", JsonNodeFactory.instance.numberNode(10))));
    List<NamedTypedValue> goals = Collections.singletonList(new NamedTypedValue("unitIn", new UnitValue("number", "number", JsonNodeFactory.instance.numberNode(10))));
    List<CounterfactualSearchDomain> searchDomains = Collections.singletonList(new CounterfactualSearchDomain("age", new CounterfactualSearchDomainUnitValue("integer", "integer", Boolean.TRUE, new CounterfactualDomainRange(JsonNodeFactory.instance.numberNode(0), JsonNodeFactory.instance.numberNode(10)))));
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest("executionId", "serviceUrl", modelIdentifier, "counterfactualId", originalInputs, goals, searchDomains, 60L);
    CounterfactualExplainabilityRequestMarshaller marshaller = new CounterfactualExplainabilityRequestMarshaller(new ObjectMapper());
    marshaller.writeTo(writer, request);
    CounterfactualExplainabilityRequest retrieved = marshaller.readFrom(reader);
    Assertions.assertEquals(request.getExecutionId(), retrieved.getExecutionId());
    Assertions.assertEquals(request.getCounterfactualId(), retrieved.getCounterfactualId());
    Assertions.assertEquals(goals.get(0).getName(), retrieved.getGoals().stream().findFirst().get().getName());
    Assertions.assertEquals(searchDomains.get(0).getName(), retrieved.getSearchDomains().stream().findFirst().get().getName());
    Assertions.assertEquals(0, ((CounterfactualDomainRange) retrieved.getSearchDomains().stream().findFirst().get().getValue().toUnit().getDomain()).getLowerBound().asInt());
    // Check the deserialized copy. Checking `request` would always pass, even if the marshaller lost the field.
    Assertions.assertEquals(60L, retrieved.getMaxRunningTimeSeconds());
}
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) CounterfactualDomainRange(org.kie.kogito.explainability.api.CounterfactualDomainRange) ModelIdentifier(org.kie.kogito.explainability.api.ModelIdentifier) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) Test(org.junit.jupiter.api.Test)

Example 2 with CounterfactualSearchDomain

use of org.kie.kogito.explainability.api.CounterfactualSearchDomain in project kogito-apps by kiegroup.

In class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithStructuredSearchDomains.

/**
 * Verifies that {@code getPrediction} rejects a request whose only search domain is a structure
 * (a nested map of unit domains) rather than a flat unit value.
 */
@Test
public void testGetPredictionWithStructuredSearchDomains() {
    // Leaf: a numeric range [10, 20] nested under the structure below.
    CounterfactualSearchDomainUnitValue nestedLeaf =
            new CounterfactualSearchDomainUnitValue("number", "number", true,
                    new CounterfactualDomainRange(new IntNode(10), new IntNode(20)));
    CounterfactualSearchDomain structuredDomain =
            new CounterfactualSearchDomain("input1",
                    new CounterfactualSearchDomainStructureValue("number", Map.of("input2b", nestedLeaf)));

    CounterfactualExplainabilityRequest structuredRequest = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            Collections.emptyList(),
            List.of(structuredDomain),
            MAX_RUNNING_TIME_SECONDS);

    assertThrows(IllegalArgumentException.class, () -> handler.getPrediction(structuredRequest));
}
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) IntNode(com.fasterxml.jackson.databind.node.IntNode) CounterfactualDomainRange(org.kie.kogito.explainability.api.CounterfactualDomainRange) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) CounterfactualSearchDomainStructureValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureValue) Test(org.junit.jupiter.api.Test)

Example 3 with CounterfactualSearchDomain

use of org.kie.kogito.explainability.api.CounterfactualSearchDomain in project kogito-apps by kiegroup.

In class CounterfactualExplainerServiceHandler, method getPrediction.

/**
 * Builds a {@link CounterfactualPrediction} from a counterfactual explainability request.
 * <p>
 * A requested maximum running time larger than {@code kafkaMaxRecordAgeSeconds} is clamped to that value,
 * and the clamping is logged.
 *
 * @param request the incoming counterfactual request
 * @return a prediction built from the request's inputs/search domains, goals and execution id
 * @throws IllegalArgumentException if the model is not flat (see FAI-473 / FAI-474)
 */
@Override
public Prediction getPrediction(CounterfactualExplainabilityRequest request) {
    Collection<NamedTypedValue> sortedGoals = toMapBasedSorting(request.getGoals());
    Collection<CounterfactualSearchDomain> domains = request.getSearchDomains();
    Collection<NamedTypedValue> inputs = request.getOriginalInputs();
    Long runningTimeLimit = request.getMaxRunningTimeSeconds();

    boolean exceedsMessagingLimit = Objects.nonNull(runningTimeLimit) && runningTimeLimit > kafkaMaxRecordAgeSeconds;
    if (exceedsMessagingLimit) {
        LOGGER.info(String.format("Maximum Running Timeout set to '%d's since the provided value '%d's exceeded the Messaging sub-system configuration '%d's.", kafkaMaxRecordAgeSeconds, runningTimeLimit, kafkaMaxRecordAgeSeconds));
        runningTimeLimit = kafkaMaxRecordAgeSeconds;
    }

    // See https://issues.redhat.com/browse/FAI-473 and https://issues.redhat.com/browse/FAI-474
    if (isUnsupportedModel(inputs, sortedGoals, domains)) {
        throw new IllegalArgumentException("Counterfactual explanations only support flat models.");
    }

    PredictionInput predictionInput = new PredictionInput(toFeatureList(inputs, domains));
    PredictionOutput predictionOutput = new PredictionOutput(toOutputList(sortedGoals));
    UUID executionUuid = UUID.fromString(request.getExecutionId());
    return new CounterfactualPrediction(predictionInput, predictionOutput, null, executionUuid, runningTimeLimit);
}
Also used : NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction)

Example 4 with CounterfactualSearchDomain

use of org.kie.kogito.explainability.api.CounterfactualSearchDomain in project kogito-apps by kiegroup.

In class ConversionUtils, method toFeatureList.

// //////////////////////////
// TO EXPLAINABILITY MODEL
// //////////////////////////
/*
     * ---------------------------------------
     * Feature conversion
     * ---------------------------------------
     */
/**
 * Converts named typed values into explainability {@link Feature}s, pairing each one with a feature
 * domain and a constraint flag derived from {@code searchDomains}.
 * <p>
 * When {@code searchDomains} is empty this delegates to {@code toFeatureList(values)}. Otherwise the
 * i-th value (in the iteration order of {@code values}) is paired positionally with the i-th element of
 * {@code toFeatureDomainList(searchDomains)} and {@code toFeatureConstraintList(searchDomains)}.
 * NOTE(review): this assumes callers supply values and search domains in the same order; confirm.
 *
 * @param values        the named values to convert
 * @param searchDomains search domains aligned positionally with {@code values}; may be empty
 * @return one feature per element of {@code values}
 * @throws IndexOutOfBoundsException if there are fewer derived domains/constraints than values
 */
public static List<Feature> toFeatureList(Collection<? extends HasNameValue<TypedValue>> values, Collection<CounterfactualSearchDomain> searchDomains) {
    if (searchDomains.isEmpty()) {
        return toFeatureList(values);
    }
    final List<FeatureDomain> featureDomains = toFeatureDomainList(searchDomains);
    final List<Boolean> featureConstraints = toFeatureConstraintList(searchDomains);
    // A plain indexed loop keeps the pairing explicit. The previous version mutated an external counter
    // inside Stream.map, which is a side-effecting intermediate operation.
    final List<Feature> features = new ArrayList<>(values.size());
    int index = 0;
    for (HasNameValue<TypedValue> hnv : values) {
        features.add(toFeature(hnv.getName(), hnv.getValue(), featureDomains.get(index), featureConstraints.get(index)));
        index++;
    }
    return features;
}
Also used : FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) Feature(org.kie.kogito.explainability.model.Feature) BiFunction(java.util.function.BiFunction) CounterfactualSearchDomainValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainValue) CounterfactualDomainRange(org.kie.kogito.explainability.api.CounterfactualDomainRange) Value(org.kie.kogito.explainability.model.Value) Function(java.util.function.Function) ArrayList(java.util.ArrayList) Pair(org.apache.commons.lang3.tuple.Pair) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Map(java.util.Map) HasNameValue(org.kie.kogito.explainability.api.HasNameValue) EmptyFeatureDomain(org.kie.kogito.explainability.model.domain.EmptyFeatureDomain) JsonNode(com.fasterxml.jackson.databind.JsonNode) JsonObject(io.vertx.core.json.JsonObject) CounterfactualDomainCategorical(org.kie.kogito.explainability.api.CounterfactualDomainCategorical) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) CategoricalFeatureDomain(org.kie.kogito.explainability.model.domain.CategoricalFeatureDomain) Collection(java.util.Collection) Collectors(java.util.stream.Collectors) Type(org.kie.kogito.explainability.model.Type) TextNode(com.fasterxml.jackson.databind.node.TextNode) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) Objects(java.util.Objects) List(java.util.List) CounterfactualDomain(org.kie.kogito.explainability.api.CounterfactualDomain) CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue) DoubleNode(com.fasterxml.jackson.databind.node.DoubleNode) NumericalFeatureDomain(org.kie.kogito.explainability.model.domain.NumericalFeatureDomain) Output(org.kie.kogito.explainability.model.Output) Optional(java.util.Optional) BooleanNode(com.fasterxml.jackson.databind.node.BooleanNode) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) Collections(java.util.Collections) FeatureDomain(org.kie.kogito.explainability.model.domain.FeatureDomain)

Example 5 with CounterfactualSearchDomain

use of org.kie.kogito.explainability.api.CounterfactualSearchDomain in project kogito-apps by kiegroup.

In class TrustyServiceTest, method doGivenStoredExecutionWhenCounterfactualRequestIsMadeThenExplainabilityEventIsEmittedTest.

/**
 * Shared scenario: with a stored decision for {@code TEST_EXECUTION_ID}, requesting counterfactuals must
 * emit exactly one {@link CounterfactualExplainabilityRequest} event for that execution.
 *
 * @param domain the counterfactual domain applied to every search-domain leaf
 */
@SuppressWarnings("unchecked")
void doGivenStoredExecutionWhenCounterfactualRequestIsMadeThenExplainabilityEventIsEmittedTest(CounterfactualDomain domain) {
    Storage<String, Decision> decisionStorage = mock(Storage.class);
    Storage<String, CounterfactualExplainabilityRequest> counterfactualStorage = mock(Storage.class);
    ArgumentCaptor<BaseExplainabilityRequest> eventCaptor = ArgumentCaptor.forClass(BaseExplainabilityRequest.class);

    when(decisionStorage.containsKey(eq(TEST_EXECUTION_ID))).thenReturn(true);
    when(trustyStorageServiceMock.getDecisionsStorage()).thenReturn(decisionStorage);
    when(trustyStorageServiceMock.getCounterfactualRequestStorage()).thenReturn(counterfactualStorage);
    when(decisionStorage.get(eq(TEST_EXECUTION_ID))).thenReturn(TrustyServiceTestUtils.buildCorrectDecision(TEST_EXECUTION_ID));

    // The Goals structures must be comparable to the original decisions outcomes.
    List<NamedTypedValue> goals = List.of(
            new NamedTypedValue("Fine", new StructureValue("tFine", Map.of(
                    "Amount", new UnitValue("number", "number", new IntNode(0)),
                    "Points", new UnitValue("number", "number", new IntNode(0))))),
            new NamedTypedValue("Should the driver be suspended?", new UnitValue("string", "string", new TextNode("No"))));

    // The Search Domain structures must be identical those of the original decision inputs.
    List<CounterfactualSearchDomain> searchDomains = List.of(
            new CounterfactualSearchDomain("Violation", new CounterfactualSearchDomainStructureValue("tViolation", Map.of(
                    "Type", new CounterfactualSearchDomainUnitValue("string", "string", true, domain),
                    "Actual Speed", new CounterfactualSearchDomainUnitValue("number", "number", true, domain),
                    "Speed Limit", new CounterfactualSearchDomainUnitValue("number", "number", true, domain)))),
            new CounterfactualSearchDomain("Driver", new CounterfactualSearchDomainStructureValue("tDriver", Map.of(
                    "Age", new CounterfactualSearchDomainUnitValue("number", "number", true, domain),
                    "Points", new CounterfactualSearchDomainUnitValue("number", "number", true, domain)))));

    trustyService.requestCounterfactuals(TEST_EXECUTION_ID, goals, searchDomains);

    verify(explainabilityRequestProducerMock).sendEvent(eventCaptor.capture());
    BaseExplainabilityRequest emitted = eventCaptor.getValue();
    assertNotNull(emitted);
    assertTrue(emitted instanceof CounterfactualExplainabilityRequest);
    assertEquals(TEST_EXECUTION_ID, ((CounterfactualExplainabilityRequest) emitted).getExecutionId());
}
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) TextNode(com.fasterxml.jackson.databind.node.TextNode) ArgumentMatchers.anyString(org.mockito.ArgumentMatchers.anyString) Decision(org.kie.kogito.trusty.storage.api.model.decision.Decision) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) IntNode(com.fasterxml.jackson.databind.node.IntNode) BaseExplainabilityRequest(org.kie.kogito.explainability.api.BaseExplainabilityRequest) StructureValue(org.kie.kogito.tracing.typedvalue.StructureValue) CounterfactualSearchDomainStructureValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureValue) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain)

Aggregations

CounterfactualSearchDomain (org.kie.kogito.explainability.api.CounterfactualSearchDomain)28 Test (org.junit.jupiter.api.Test)22 CounterfactualDomainRange (org.kie.kogito.explainability.api.CounterfactualDomainRange)19 CounterfactualExplainabilityRequest (org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest)19 NamedTypedValue (org.kie.kogito.explainability.api.NamedTypedValue)17 IntNode (com.fasterxml.jackson.databind.node.IntNode)16 CounterfactualSearchDomainUnitValue (org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue)12 TextNode (com.fasterxml.jackson.databind.node.TextNode)11 ArrayList (java.util.ArrayList)9 CounterfactualDomainCategorical (org.kie.kogito.explainability.api.CounterfactualDomainCategorical)9 UnitValue (org.kie.kogito.tracing.typedvalue.UnitValue)9 List (java.util.List)6 JsonNode (com.fasterxml.jackson.databind.JsonNode)5 Map (java.util.Map)4 CounterfactualDomain (org.kie.kogito.explainability.api.CounterfactualDomain)4 Feature (org.kie.kogito.explainability.model.Feature)4 StructureValue (org.kie.kogito.tracing.typedvalue.StructureValue)4 CounterfactualRequestResponse (org.kie.kogito.trusty.service.common.responses.CounterfactualRequestResponse)4 Decision (org.kie.kogito.trusty.storage.api.model.decision.Decision)4 ArgumentMatchers.anyString (org.mockito.ArgumentMatchers.anyString)4