Usage of org.kie.kogito.explainability.api.CounterfactualDomainRange in project kogito-apps by kiegroup.
In class CounterfactualExplainabilityRequestMarshallerTest, method testWriteAndRead:
/**
 * Round-trips a {@link CounterfactualExplainabilityRequest} through the marshaller and checks
 * that the deserialized copy carries the same identifiers, goals, search domains and timeout.
 */
@Test
public void testWriteAndRead() throws IOException {
    ModelIdentifier modelIdentifier = new ModelIdentifier("resourceType", "resourceId");
    List<NamedTypedValue> originalInputs = Collections.singletonList(new NamedTypedValue("unitIn", new UnitValue("number", "number", JsonNodeFactory.instance.numberNode(10))));
    List<NamedTypedValue> goals = Collections.singletonList(new NamedTypedValue("unitIn", new UnitValue("number", "number", JsonNodeFactory.instance.numberNode(10))));
    List<CounterfactualSearchDomain> searchDomains = Collections.singletonList(new CounterfactualSearchDomain("age", new CounterfactualSearchDomainUnitValue("integer", "integer", Boolean.TRUE, new CounterfactualDomainRange(JsonNodeFactory.instance.numberNode(0), JsonNodeFactory.instance.numberNode(10)))));
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest("executionId", "serviceUrl", modelIdentifier, "counterfactualId", originalInputs, goals, searchDomains, 60L);
    CounterfactualExplainabilityRequestMarshaller marshaller = new CounterfactualExplainabilityRequestMarshaller(new ObjectMapper());

    // NOTE(review): `writer` and `reader` are not declared here; presumably test fields backed by the same buffer — confirm.
    marshaller.writeTo(writer, request);
    CounterfactualExplainabilityRequest retrieved = marshaller.readFrom(reader);

    Assertions.assertEquals(request.getExecutionId(), retrieved.getExecutionId());
    Assertions.assertEquals(request.getCounterfactualId(), retrieved.getCounterfactualId());
    Assertions.assertEquals(goals.get(0).getName(), retrieved.getGoals().stream().findFirst().get().getName());
    Assertions.assertEquals(searchDomains.get(0).getName(), retrieved.getSearchDomains().stream().findFirst().get().getName());

    CounterfactualDomainRange retrievedRange = (CounterfactualDomainRange) retrieved.getSearchDomains().stream().findFirst().get().getValue().toUnit().getDomain();
    Assertions.assertEquals(0, retrievedRange.getLowerBound().asInt());
    Assertions.assertEquals(10, retrievedRange.getUpperBound().asInt());

    // Assert on the deserialized request; asserting on `request` would never exercise the round trip.
    Assertions.assertEquals(60L, retrieved.getMaxRunningTimeSeconds());
}
Usage of org.kie.kogito.explainability.api.CounterfactualDomainRange in project kogito-apps by kiegroup.
In class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithStructuredSearchDomains:
/**
 * Builds a request whose only search domain is a structure wrapping a numeric range leaf,
 * and expects the handler to reject it with an {@link IllegalArgumentException}.
 */
@Test
public void testGetPredictionWithStructuredSearchDomains() {
    CounterfactualDomainRange leafRange = new CounterfactualDomainRange(new IntNode(10), new IntNode(20));
    CounterfactualSearchDomainUnitValue leafValue = new CounterfactualSearchDomainUnitValue("number", "number", true, leafRange);
    CounterfactualSearchDomain structuredDomain =
            new CounterfactualSearchDomain("input1", new CounterfactualSearchDomainStructureValue("number", Map.of("input2b", leafValue)));

    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            Collections.emptyList(),
            List.of(structuredDomain),
            MAX_RUNNING_TIME_SECONDS);

    assertThrows(IllegalArgumentException.class, () -> handler.getPrediction(request));
}
Usage of org.kie.kogito.explainability.api.CounterfactualDomainRange in project kogito-apps by kiegroup.
In class ConversionUtilsTest, method testToFeatureDomain_UnitRangeDouble:
/**
 * Converting a double-typed unit value with a numeric range must yield a
 * {@link NumericalFeatureDomain} that keeps both bounds and has no categories.
 */
@Test
void testToFeatureDomain_UnitRangeDouble() {
    // A negative, non-integral lower bound and the largest finite double as the upper bound.
    DoubleNode lower = DoubleNode.valueOf(-273.15);
    DoubleNode upper = DoubleNode.valueOf(Double.MAX_VALUE);
    CounterfactualSearchDomainUnitValue unitValue =
            new CounterfactualSearchDomainUnitValue("double", "double", true, new CounterfactualDomainRange(lower, upper));

    FeatureDomain converted = ConversionUtils.toFeatureDomain(unitValue);

    assertTrue(converted instanceof NumericalFeatureDomain);
    NumericalFeatureDomain numeric = (NumericalFeatureDomain) converted;
    assertEquals(-273.15, numeric.getLowerBound());
    assertEquals(Double.MAX_VALUE, numeric.getUpperBound());
    assertNull(numeric.getCategories());
}
Usage of org.kie.kogito.explainability.api.CounterfactualDomainRange in project kogito-apps by kiegroup.
In class ConversionUtils, method toCounterfactualSearchDomain:
/**
 * Converts a {@link CounterfactualDomain} into a {@link FeatureDomain} for counterfactual search.
 *
 * @param domain the domain to convert; {@code null} maps to an empty feature domain
 * @return the converted domain:
 *         <ul>
 *         <li>an {@code EmptyFeatureDomain} for {@code null};</li>
 *         <li>a {@code NumericalFeatureDomain} for a range with numeric bounds;</li>
 *         <li>a {@code CategoricalFeatureDomain} for textual categories;</li>
 *         <li>{@link Optional#empty()} for any other domain type.</li>
 *         </ul>
 * @throws IllegalArgumentException if a range has a missing or non-numeric bound, or a
 *         categorical domain contains a non-textual category
 */
static Optional<FeatureDomain> toCounterfactualSearchDomain(CounterfactualDomain domain) {
    if (Objects.isNull(domain)) {
        return Optional.of(EmptyFeatureDomain.create());
    } else if (domain instanceof CounterfactualDomainRange) {
        CounterfactualDomainRange range = (CounterfactualDomainRange) domain;
        JsonNode lb = range.getLowerBound();
        JsonNode ub = range.getUpperBound();
        // Null-check first so a missing bound is reported as an unsupported range, not as an NPE.
        if (Objects.nonNull(lb) && Objects.nonNull(ub) && lb.isNumber() && ub.isNumber()) {
            return Optional.of(NumericalFeatureDomain.create(lb.asDouble(), ub.asDouble()));
        }
        throw new IllegalArgumentException(String.format("Unsupported CounterfactualDomainRange [%s, %s]",
                Objects.isNull(lb) ? "null" : lb.asText(),
                Objects.isNull(ub) ? "null" : ub.asText()));
    } else if (domain instanceof CounterfactualDomainCategorical) {
        CounterfactualDomainCategorical categorical = (CounterfactualDomainCategorical) domain;
        Collection<JsonNode> jsonCategories = categorical.getCategories();
        if (jsonCategories.stream().allMatch(JsonNode::isTextual)) {
            String[] categories = jsonCategories.stream().map(JsonNode::asText).toArray(String[]::new);
            return Optional.of(CategoricalFeatureDomain.create(categories));
        }
        // Describe the actual JSON values so the offending (non-textual) categories are visible in the message.
        String offending = jsonCategories.stream().map(String::valueOf).collect(Collectors.joining(", "));
        throw new IllegalArgumentException(String.format("Unsupported CounterfactualDomainCategorical [%s]", offending));
    }
    return Optional.empty();
}
Usage of org.kie.kogito.explainability.api.CounterfactualDomainRange in project kogito-apps by kiegroup.
In class AbstractTrustyServiceIT, method testCounterfactuals_StoreSingleAndRetrieveSingleWithEmptyDefinition:
/**
 * Stores one execution and requests counterfactuals with no goals and a single unit search
 * domain. It then checks that the stored request can be fetched back by execution and
 * counterfactual id.
 */
@Test
public void testCounterfactuals_StoreSingleAndRetrieveSingleWithEmptyDefinition() {
    String executionId = "myCFExecution1";
    storeExecution(executionId, 0L);

    // The Goals structures must be comparable to the original decision outcomes.
    // The Search Domain structures must be identical to those of the original decision inputs.
    CounterfactualSearchDomain searchDomain = buildSearchDomainUnit("test", "number", new CounterfactualDomainRange(new IntNode(1), new IntNode(2)));

    CounterfactualExplainabilityRequest request = trustyService.requestCounterfactuals(executionId, Collections.emptyList(), Collections.singletonList(searchDomain));
    assertNotNull(request);
    // Expected value first, so JUnit's failure message labels expected/actual correctly.
    assertEquals(executionId, request.getExecutionId());
    assertNotNull(request.getCounterfactualId());

    CounterfactualExplainabilityRequest result = trustyService.getCounterfactualRequest(executionId, request.getCounterfactualId());
    assertNotNull(result);
    assertEquals(request.getExecutionId(), result.getExecutionId());
    assertEquals(request.getCounterfactualId(), result.getCounterfactualId());
}
Aggregations