Usage of org.kie.kogito.tracing.typedvalue.UnitValue in the kogito-apps project by kiegroup.
From class CounterfactualExplainabilityResultMarshallerTest, method testWriteAndRead.
/**
 * Round-trips a {@link CounterfactualExplainabilityResult} through the marshaller and verifies
 * that the identifiers, status, stage and the single unit-valued input and output come back
 * unchanged.
 */
@Test
public void testWriteAndRead() throws IOException {
    UnitValue inValue = new UnitValue("number", "number", JsonNodeFactory.instance.numberNode(10));
    UnitValue outValue = new UnitValue("number", "number", JsonNodeFactory.instance.numberNode(55));
    List<NamedTypedValue> inputs = List.of(new NamedTypedValue("unitIn", inValue));
    List<NamedTypedValue> outputs = List.of(new NamedTypedValue("unitOut", outValue));

    CounterfactualExplainabilityResult original = new CounterfactualExplainabilityResult(
            "executionId",
            "counterfactualId",
            "solutionId",
            0L,
            ExplainabilityStatus.SUCCEEDED,
            "statusDetail",
            true,
            CounterfactualExplainabilityResult.Stage.FINAL,
            inputs,
            outputs);

    CounterfactualExplainabilityResultMarshaller marshaller =
            new CounterfactualExplainabilityResultMarshaller(new ObjectMapper());
    // NOTE(review): writer/reader are test-class fixtures not visible here; presumably they share
    // one underlying buffer so that what is written can be read back.
    marshaller.writeTo(writer, original);
    CounterfactualExplainabilityResult retrieved = marshaller.readFrom(reader);

    // Snapshot the retrieved collections into lists so elements can be addressed by index.
    List<NamedTypedValue> retrievedInputs = List.copyOf(retrieved.getInputs());
    List<NamedTypedValue> retrievedOutputs = List.copyOf(retrieved.getOutputs());

    // Scalar properties.
    Assertions.assertEquals(original.getExecutionId(), retrieved.getExecutionId());
    Assertions.assertEquals(original.getCounterfactualId(), retrieved.getCounterfactualId());
    Assertions.assertEquals(original.getSolutionId(), retrieved.getSolutionId());
    Assertions.assertEquals(original.getSequenceId(), retrieved.getSequenceId());
    Assertions.assertEquals(original.getStatus(), retrieved.getStatus());
    Assertions.assertEquals(original.getStatusDetails(), retrieved.getStatusDetails());
    Assertions.assertEquals(original.getStage(), retrieved.getStage());

    // The single input.
    NamedTypedValue expectedIn = inputs.get(0);
    Assertions.assertEquals(1, retrievedInputs.size());
    NamedTypedValue actualIn = retrievedInputs.get(0);
    Assertions.assertEquals(expectedIn.getName(), actualIn.getName());
    Assertions.assertEquals(expectedIn.getValue().getKind(), actualIn.getValue().getKind());
    Assertions.assertEquals(expectedIn.getValue().getType(), actualIn.getValue().getType());
    Assertions.assertEquals(expectedIn.getValue().toUnit().getValue(), actualIn.getValue().toUnit().getValue());

    // The single output.
    NamedTypedValue expectedOut = outputs.get(0);
    Assertions.assertEquals(1, retrievedOutputs.size());
    NamedTypedValue actualOut = retrievedOutputs.get(0);
    Assertions.assertEquals(expectedOut.getName(), actualOut.getName());
    Assertions.assertEquals(expectedOut.getValue().getKind(), actualOut.getValue().getKind());
    Assertions.assertEquals(expectedOut.getValue().getType(), actualOut.getValue().getType());
    Assertions.assertEquals(expectedOut.getValue().toUnit().getValue(), actualOut.getValue().toUnit().getValue());
}
Usage of org.kie.kogito.tracing.typedvalue.UnitValue in the kogito-apps project by kiegroup.
From class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithStructuredOutputModel.
/**
 * A request whose output is a {@code StructureValue} (a nested value) is expected to be rejected
 * by {@code handler.getPrediction} with an {@link IllegalArgumentException}.
 */
@Test
public void testGetPredictionWithStructuredOutputModel() {
    StructureValue structuredOutput =
            new StructureValue("number", Map.of("input2b", new UnitValue("number", new IntNode(55))));
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            List.of(new NamedTypedValue("input1", structuredOutput)),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    assertThrows(IllegalArgumentException.class, () -> handler.getPrediction(request));
}
Usage of org.kie.kogito.tracing.typedvalue.UnitValue in the kogito-apps project by kiegroup.
From class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatInputModel.
/**
 * A request with a single unit-valued input and no outputs should produce a
 * {@code CounterfactualPrediction} with one NUMBER feature named "input1" (value 20), where every
 * feature has an empty domain and is constrained, the output list is empty, and the request's
 * max running time is carried over.
 */
@Test
public void testGetPredictionWithFlatInputModel() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            List.of(new NamedTypedValue("input1", new UnitValue("number", new IntNode(20)))),
            Collections.emptyList(),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);
    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;

    assertEquals(1, counterfactualPrediction.getInput().getFeatures().size());
    Optional<Feature> oInput1 = counterfactualPrediction.getInput().getFeatures().stream()
            .filter(f -> f.getName().equals("input1"))
            .findFirst();
    assertTrue(oInput1.isPresent());
    Feature input1 = oInput1.get();
    assertEquals(Type.NUMBER, input1.getType());
    assertEquals(20, input1.getValue().asNumber());
    assertTrue(counterfactualPrediction.getInput().getFeatures().stream().allMatch(f -> f.getDomain().isEmpty()));
    assertTrue(counterfactualPrediction.getInput().getFeatures().stream().allMatch(Feature::isConstrained));

    assertTrue(counterfactualPrediction.getOutput().getOutputs().isEmpty());
    // assertEquals takes (expected, actual): the request's value is the expectation, so it goes first
    // to keep failure messages accurate.
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
Usage of org.kie.kogito.tracing.typedvalue.UnitValue in the kogito-apps project by kiegroup.
From class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatOutputModel.
/**
 * A request with no inputs and a single unit-valued output should produce a
 * {@code CounterfactualPrediction} with one NUMBER output named "output1" (value 20), no input
 * features, and the request's max running time carried over.
 */
@Test
public void testGetPredictionWithFlatOutputModel() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            List.of(new NamedTypedValue("output1", new UnitValue("number", new IntNode(20)))),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);
    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;

    assertEquals(1, counterfactualPrediction.getOutput().getOutputs().size());
    Optional<Output> oOutput1 = counterfactualPrediction.getOutput().getOutputs().stream()
            .filter(o -> o.getName().equals("output1"))
            .findFirst();
    assertTrue(oOutput1.isPresent());
    Output output1 = oOutput1.get();
    assertEquals(Type.NUMBER, output1.getType());
    assertEquals(20, output1.getValue().asNumber());

    assertTrue(counterfactualPrediction.getInput().getFeatures().isEmpty());
    // assertEquals takes (expected, actual): the request's value is the expectation, so it goes first
    // to keep failure messages accurate.
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
Usage of org.kie.kogito.tracing.typedvalue.UnitValue in the kogito-apps project by kiegroup.
From class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithCollectionInputModel.
/**
 * A request whose input is a {@code CollectionValue} is expected to be rejected by
 * {@code handler.getPrediction} with an {@link IllegalArgumentException}.
 */
@Test
public void testGetPredictionWithCollectionInputModel() {
    CollectionValue collectionInput =
            new CollectionValue("number", List.of(new UnitValue("number", new IntNode(100))));
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            List.of(new NamedTypedValue("input1", collectionInput)),
            Collections.emptyList(),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    assertThrows(IllegalArgumentException.class, () -> handler.getPrediction(request));
}
Aggregations