Example 16 with UnitValue

Use of org.kie.kogito.tracing.typedvalue.UnitValue in project kogito-apps by kiegroup.

Class CounterfactualExplainabilityResultMarshallerTest, method testWriteAndRead.

@Test
public void testWriteAndRead() throws IOException {
    // Not shown in this excerpt: the imports, and the `writer` and `reader` objects, which are presumably fixtures of the test class.
    // Build one input and one output, each a UnitValue wrapping a JSON number.
    List<NamedTypedValue> inputs = List.of(new NamedTypedValue("unitIn", new UnitValue("number", "number", JsonNodeFactory.instance.numberNode(10))));
    List<NamedTypedValue> outputs = List.of(new NamedTypedValue("unitOut", new UnitValue("number", "number", JsonNodeFactory.instance.numberNode(55))));
    CounterfactualExplainabilityResult explainabilityResult = new CounterfactualExplainabilityResult("executionId", "counterfactualId", "solutionId", 0L, ExplainabilityStatus.SUCCEEDED, "statusDetail", true, CounterfactualExplainabilityResult.Stage.FINAL, inputs, outputs);
    CounterfactualExplainabilityResultMarshaller marshaller = new CounterfactualExplainabilityResultMarshaller(new ObjectMapper());
    // Round trip: write the result out, then read it back.
    marshaller.writeTo(writer, explainabilityResult);
    CounterfactualExplainabilityResult retrieved = marshaller.readFrom(reader);
    List<NamedTypedValue> retrievedInputs = List.of(retrieved.getInputs().toArray(new NamedTypedValue[0]));
    List<NamedTypedValue> retrievedOutputs = List.of(retrieved.getOutputs().toArray(new NamedTypedValue[0]));
    // The scalar fields must survive the round trip.
    Assertions.assertEquals(explainabilityResult.getExecutionId(), retrieved.getExecutionId());
    Assertions.assertEquals(explainabilityResult.getCounterfactualId(), retrieved.getCounterfactualId());
    Assertions.assertEquals(explainabilityResult.getSolutionId(), retrieved.getSolutionId());
    Assertions.assertEquals(explainabilityResult.getSequenceId(), retrieved.getSequenceId());
    Assertions.assertEquals(explainabilityResult.getStatus(), retrieved.getStatus());
    Assertions.assertEquals(explainabilityResult.getStatusDetails(), retrieved.getStatusDetails());
    Assertions.assertEquals(explainabilityResult.getStage(), retrieved.getStage());
    // The single input must keep its name, kind, type and value.
    Assertions.assertEquals(1, retrievedInputs.size());
    Assertions.assertEquals(inputs.get(0).getName(), retrievedInputs.get(0).getName());
    Assertions.assertEquals(inputs.get(0).getValue().getKind(), retrievedInputs.get(0).getValue().getKind());
    Assertions.assertEquals(inputs.get(0).getValue().getType(), retrievedInputs.get(0).getValue().getType());
    Assertions.assertEquals(inputs.get(0).getValue().toUnit().getValue(), retrievedInputs.get(0).getValue().toUnit().getValue());
    // The same must hold for the single output.
    Assertions.assertEquals(1, retrievedOutputs.size());
    Assertions.assertEquals(outputs.get(0).getName(), retrievedOutputs.get(0).getName());
    Assertions.assertEquals(outputs.get(0).getValue().getKind(), retrievedOutputs.get(0).getValue().getKind());
    Assertions.assertEquals(outputs.get(0).getValue().getType(), retrievedOutputs.get(0).getValue().getType());
    Assertions.assertEquals(outputs.get(0).getValue().toUnit().getValue(), retrievedOutputs.get(0).getValue().toUnit().getValue());
}
Also used : NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CounterfactualExplainabilityResult(org.kie.kogito.explainability.api.CounterfactualExplainabilityResult) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) Test(org.junit.jupiter.api.Test)

Example 17 with UnitValue

Use of org.kie.kogito.tracing.typedvalue.UnitValue in project kogito-apps by kiegroup.

Class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithStructuredOutputModel.

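@Test
public void testGetPredictionWithStructuredOutputModel() {
    // EXECUTION_ID, SERVICE_URL and the other constants, as well as `handler`, belong to the test class (not shown).
    // The second list (the outputs, going by the test name) holds a StructureValue that wraps a UnitValue, i.e. a nested model.
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID, Collections.emptyList(), List.of(new NamedTypedValue("input1", new StructureValue("number", Map.of("input2b", new UnitValue("number", new IntNode(55)))))), Collections.emptyList(), MAX_RUNNING_TIME_SECONDS);
    // A structured output is expected to be rejected.
    assertThrows(IllegalArgumentException.class, () -> handler.getPrediction(request));
}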
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) IntNode(com.fasterxml.jackson.databind.node.IntNode) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) StructureValue(org.kie.kogito.tracing.typedvalue.StructureValue) CounterfactualSearchDomainStructureValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureValue) Test(org.junit.jupiter.api.Test)

Example 18 with UnitValue

Use of org.kie.kogito.tracing.typedvalue.UnitValue in project kogito-apps by kiegroup.

Class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatInputModel.

@Test
public void testGetPredictionWithFlatInputModel() {
    // EXECUTION_ID, SERVICE_URL and the other constants, as well as `handler`, belong to the test class (not shown).
    // The first list (the inputs, going by the test name) holds a single flat UnitValue; the other two lists are empty.
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID, List.of(new NamedTypedValue("input1", new UnitValue("number", new IntNode(20)))), Collections.emptyList(), Collections.emptyList(), MAX_RUNNING_TIME_SECONDS);
    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;
    // Exactly one feature, named "input1", of type NUMBER with value 20.
    assertEquals(1, counterfactualPrediction.getInput().getFeatures().size());
    Optional<Feature> oInput1 = counterfactualPrediction.getInput().getFeatures().stream().filter(f -> f.getName().equals("input1")).findFirst();
    assertTrue(oInput1.isPresent());
    Feature input1 = oInput1.get();
    assertEquals(Type.NUMBER, input1.getType());
    assertEquals(20, input1.getValue().asNumber());
    // With no search domains supplied, every feature has an empty domain and is constrained.
    assertTrue(counterfactualPrediction.getInput().getFeatures().stream().allMatch(f -> f.getDomain().isEmpty()));
    assertTrue(counterfactualPrediction.getInput().getFeatures().stream().allMatch(Feature::isConstrained));
    // No outputs were requested, and the time limit is carried over from the request.
    assertTrue(counterfactualPrediction.getOutput().getOutputs().isEmpty());
    assertEquals(counterfactualPrediction.getMaxRunningTimeSeconds(), request.getMaxRunningTimeSeconds());
}
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) BeforeEach(org.junit.jupiter.api.BeforeEach) BaseExplainabilityRequest(org.kie.kogito.explainability.api.BaseExplainabilityRequest) Feature(org.kie.kogito.explainability.model.Feature) ArgumentMatchers.eq(org.mockito.ArgumentMatchers.eq) CounterfactualDomainRange(org.kie.kogito.explainability.api.CounterfactualDomainRange) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) Value(org.kie.kogito.explainability.model.Value) CounterfactualResult(org.kie.kogito.explainability.local.counterfactual.CounterfactualResult) Assertions.assertFalse(org.junit.jupiter.api.Assertions.assertFalse) Map(java.util.Map) EmptyFeatureDomain(org.kie.kogito.explainability.model.domain.EmptyFeatureDomain) BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) ExplainabilityStatus(org.kie.kogito.explainability.api.ExplainabilityStatus) UUID(java.util.UUID) Collectors(java.util.stream.Collectors) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) Test(org.junit.jupiter.api.Test) List(java.util.List) CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue) Output(org.kie.kogito.explainability.model.Output) Assertions.assertTrue(org.junit.jupiter.api.Assertions.assertTrue) Optional(java.util.Optional) CounterfactualExplainabilityResult(org.kie.kogito.explainability.api.CounterfactualExplainabilityResult) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) Mockito.mock(org.mockito.Mockito.mock) Assertions.assertThrows(org.junit.jupiter.api.Assertions.assertThrows) ArgumentMatchers.any(org.mockito.ArgumentMatchers.any) IntNode(com.fasterxml.jackson.databind.node.IntNode) Prediction(org.kie.kogito.explainability.model.Prediction) StructureValue(org.kie.kogito.tracing.typedvalue.StructureValue) CounterfactualSearchDomainStructureValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureValue) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) PredictionProviderFactory(org.kie.kogito.explainability.PredictionProviderFactory) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) CounterfactualSearchDomainCollectionValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainCollectionValue) CounterfactualExplainer(org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Mockito.verify(org.mockito.Mockito.verify) Consumer(java.util.function.Consumer) DoubleEntity(org.kie.kogito.explainability.local.counterfactual.entities.DoubleEntity) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) DoubleNode(com.fasterxml.jackson.databind.node.DoubleNode) NumericalFeatureDomain(org.kie.kogito.explainability.model.domain.NumericalFeatureDomain) BooleanNode(com.fasterxml.jackson.databind.node.BooleanNode) Collections(java.util.Collections) ModelIdentifier(org.kie.kogito.explainability.api.ModelIdentifier)

Example 19 with UnitValue

Use of org.kie.kogito.tracing.typedvalue.UnitValue in project kogito-apps by kiegroup.

Class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatOutputModel.

@Test
public void testGetPredictionWithFlatOutputModel() {
    // EXECUTION_ID, SERVICE_URL and the other constants, as well as `handler`, belong to the test class (not shown).
    // The second list (the outputs, going by the test name) holds a single flat UnitValue; the other two lists are empty.
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID, Collections.emptyList(), List.of(new NamedTypedValue("output1", new UnitValue("number", new IntNode(20)))), Collections.emptyList(), MAX_RUNNING_TIME_SECONDS);
    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;
    // Exactly one output, named "output1", of type NUMBER with value 20.
    assertEquals(1, counterfactualPrediction.getOutput().getOutputs().size());
    Optional<Output> oOutput1 = counterfactualPrediction.getOutput().getOutputs().stream().filter(f -> f.getName().equals("output1")).findFirst();
    assertTrue(oOutput1.isPresent());
    Output output1 = oOutput1.get();
    assertEquals(Type.NUMBER, output1.getType());
    assertEquals(20, output1.getValue().asNumber());
    // No inputs were supplied, and the time limit is carried over from the request.
    assertTrue(counterfactualPrediction.getInput().getFeatures().isEmpty());
    assertEquals(counterfactualPrediction.getMaxRunningTimeSeconds(), request.getMaxRunningTimeSeconds());
}
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) BeforeEach(org.junit.jupiter.api.BeforeEach) BaseExplainabilityRequest(org.kie.kogito.explainability.api.BaseExplainabilityRequest) Feature(org.kie.kogito.explainability.model.Feature) ArgumentMatchers.eq(org.mockito.ArgumentMatchers.eq) CounterfactualDomainRange(org.kie.kogito.explainability.api.CounterfactualDomainRange) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) Value(org.kie.kogito.explainability.model.Value) CounterfactualResult(org.kie.kogito.explainability.local.counterfactual.CounterfactualResult) Assertions.assertFalse(org.junit.jupiter.api.Assertions.assertFalse) Map(java.util.Map) EmptyFeatureDomain(org.kie.kogito.explainability.model.domain.EmptyFeatureDomain) BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) ExplainabilityStatus(org.kie.kogito.explainability.api.ExplainabilityStatus) UUID(java.util.UUID) Collectors(java.util.stream.Collectors) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) Test(org.junit.jupiter.api.Test) List(java.util.List) CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue) Output(org.kie.kogito.explainability.model.Output) Assertions.assertTrue(org.junit.jupiter.api.Assertions.assertTrue) Optional(java.util.Optional) CounterfactualExplainabilityResult(org.kie.kogito.explainability.api.CounterfactualExplainabilityResult) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) Mockito.mock(org.mockito.Mockito.mock) Assertions.assertThrows(org.junit.jupiter.api.Assertions.assertThrows) ArgumentMatchers.any(org.mockito.ArgumentMatchers.any) IntNode(com.fasterxml.jackson.databind.node.IntNode) Prediction(org.kie.kogito.explainability.model.Prediction) StructureValue(org.kie.kogito.tracing.typedvalue.StructureValue) CounterfactualSearchDomainStructureValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureValue) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) PredictionProviderFactory(org.kie.kogito.explainability.PredictionProviderFactory) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) CounterfactualSearchDomainCollectionValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainCollectionValue) CounterfactualExplainer(org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Mockito.verify(org.mockito.Mockito.verify) Consumer(java.util.function.Consumer) DoubleEntity(org.kie.kogito.explainability.local.counterfactual.entities.DoubleEntity) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) DoubleNode(com.fasterxml.jackson.databind.node.DoubleNode) NumericalFeatureDomain(org.kie.kogito.explainability.model.domain.NumericalFeatureDomain) BooleanNode(com.fasterxml.jackson.databind.node.BooleanNode) Collections(java.util.Collections) ModelIdentifier(org.kie.kogito.explainability.api.ModelIdentifier)

Example 20 with UnitValue

Use of org.kie.kogito.tracing.typedvalue.UnitValue in project kogito-apps by kiegroup.

Class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithCollectionInputModel.

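@Test
public void testGetPredictionWithCollectionInputModel() {
    // EXECUTION_ID, SERVICE_URL and the other constants, as well as `handler`, belong to the test class (not shown).
    // The first list (the inputs, going by the test name) holds a CollectionValue that wraps a UnitValue, i.e. a nested model.
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID, List.of(new NamedTypedValue("input1", new CollectionValue("number", List.of(new UnitValue("number", new IntNode(100)))))), Collections.emptyList(), Collections.emptyList(), MAX_RUNNING_TIME_SECONDS);
    // A collection input is expected to be rejected.
    assertThrows(IllegalArgumentException.class, () -> handler.getPrediction(request));
}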
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue) CounterfactualSearchDomainCollectionValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainCollectionValue) IntNode(com.fasterxml.jackson.databind.node.IntNode) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) Test(org.junit.jupiter.api.Test)
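
Taken together, Examples 17 to 20 suggest that handler.getPrediction accepts flat models made of UnitValue entries and throws IllegalArgumentException for structure or collection values. The sketch below illustrates that kind of check in isolation. Everything in it is an assumption made for illustration: UnitSketch, StructureSketch, CollectionSketch, NamedSketch and flatValues are hypothetical stand-ins, not the kogito-apps classes, and the real handler may validate its input differently.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical stand-ins for the typed-value hierarchy; NOT the real kogito-apps classes.
class FlatModelSketch {

    interface TypedValueSketch {}

    // A single scalar value, loosely analogous to UnitValue.
    static final class UnitSketch implements TypedValueSketch {
        final String type;
        final Object value;
        UnitSketch(String type, Object value) { this.type = type; this.value = value; }
    }

    // A named group of values, loosely analogous to StructureValue.
    static final class StructureSketch implements TypedValueSketch {
        final String type;
        final Map<String, TypedValueSketch> fields;
        StructureSketch(String type, Map<String, TypedValueSketch> fields) { this.type = type; this.fields = fields; }
    }

    // A list of values, loosely analogous to CollectionValue.
    static final class CollectionSketch implements TypedValueSketch {
        final String type;
        final List<TypedValueSketch> items;
        CollectionSketch(String type, List<TypedValueSketch> items) { this.type = type; this.items = items; }
    }

    // A value with a name, loosely analogous to NamedTypedValue.
    static final class NamedSketch {
        final String name;
        final TypedValueSketch value;
        NamedSketch(String name, TypedValueSketch value) { this.name = name; this.value = value; }
    }

    // Returns the raw values of a flat model and rejects any nested entry.
    static List<Object> flatValues(List<NamedSketch> model) {
        List<Object> values = new ArrayList<>();
        for (NamedSketch named : model) {
            if (!(named.value instanceof UnitSketch)) {
                throw new IllegalArgumentException("Unsupported nested value: " + named.name);
            }
            values.add(((UnitSketch) named.value).value);
        }
        return values;
    }

    public static void main(String[] args) {
        // Flat model: accepted.
        System.out.println(flatValues(List.of(new NamedSketch("input1", new UnitSketch("number", 20)))));
        // Collection model: rejected.
        try {
            flatValues(List.of(new NamedSketch("input1",
                    new CollectionSketch("number", List.of(new UnitSketch("number", 100))))));
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

Failing fast on the first nested entry mirrors what the tests assert (an exception rather than a partial result), but whether the real handler checks eagerly or during conversion is not visible from these excerpts.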

Aggregations

UnitValue (org.kie.kogito.tracing.typedvalue.UnitValue) 31
Test (org.junit.jupiter.api.Test) 24
NamedTypedValue (org.kie.kogito.explainability.api.NamedTypedValue) 23
IntNode (com.fasterxml.jackson.databind.node.IntNode) 18
CounterfactualSearchDomainUnitValue (org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) 17
CounterfactualExplainabilityRequest (org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) 14
StructureValue (org.kie.kogito.tracing.typedvalue.StructureValue) 12
List (java.util.List) 10
CollectionValue (org.kie.kogito.tracing.typedvalue.CollectionValue) 10
CounterfactualSearchDomain (org.kie.kogito.explainability.api.CounterfactualSearchDomain) 9
ArrayList (java.util.ArrayList) 8
CounterfactualExplainabilityResult (org.kie.kogito.explainability.api.CounterfactualExplainabilityResult) 8
CounterfactualSearchDomainStructureValue (org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureValue) 8
TypedValue (org.kie.kogito.tracing.typedvalue.TypedValue) 8
Decision (org.kie.kogito.trusty.storage.api.model.decision.Decision) 8
Collections (java.util.Collections) 7
Assertions.assertEquals (org.junit.jupiter.api.Assertions.assertEquals) 7
Assertions.assertTrue (org.junit.jupiter.api.Assertions.assertTrue) 7
CounterfactualDomainRange (org.kie.kogito.explainability.api.CounterfactualDomainRange) 7
ExplainabilityStatus (org.kie.kogito.explainability.api.ExplainabilityStatus) 7