Example 26 with SeldonMessage

use of io.seldon.protos.PredictionProtos.SeldonMessage in project seldon-core by SeldonIO.

the class AverageCombinerUnit method aggregate.

@Override
public SeldonMessage aggregate(List<SeldonMessage> outputs, PredictiveUnitState state) {
    if (outputs.isEmpty()) {
        throw new APIException(APIException.ApiExceptionType.ENGINE_INVALID_COMBINER_RESPONSE, String.format("Combiner received no inputs"));
    }
    int[] shape = PredictorUtils.getShape(outputs.get(0).getData());
    if (shape == null) {
        throw new APIException(APIException.ApiExceptionType.ENGINE_INVALID_COMBINER_RESPONSE, String.format("Combiner cannot extract data shape"));
    }
    if (shape.length != 2) {
        throw new APIException(APIException.ApiExceptionType.ENGINE_INVALID_COMBINER_RESPONSE, String.format("Combiner received data that is not 2 dimensional"));
    }
    INDArray currentSum = Nd4j.zeros(shape[0], shape[1]);
    SeldonMessage.Builder respBuilder = SeldonMessage.newBuilder();
    for (SeldonMessage output : outputs) {
        DefaultData inputData = output.getData();
        int[] inputShape = PredictorUtils.getShape(inputData);
        if (inputShape == null) {
            throw new APIException(APIException.ApiExceptionType.ENGINE_INVALID_COMBINER_RESPONSE, String.format("Combiner cannot extract data shape"));
        }
        if (inputShape.length != 2) {
            throw new APIException(APIException.ApiExceptionType.ENGINE_INVALID_COMBINER_RESPONSE, String.format("Combiner received data that is not 2 dimensional"));
        }
        if (inputShape[0] != shape[0]) {
            throw new APIException(APIException.ApiExceptionType.ENGINE_INVALID_COMBINER_RESPONSE, String.format("Expected batch length %d but found %d", shape[0], inputShape[0]));
        }
        if (inputShape[1] != shape[1]) {
            // The second dimension is the per-row feature length, not the batch length
            throw new APIException(APIException.ApiExceptionType.ENGINE_INVALID_COMBINER_RESPONSE, String.format("Expected feature length %d but found %d", shape[1], inputShape[1]));
        }
        INDArray inputArr = PredictorUtils.getINDArray(inputData);
        currentSum = currentSum.add(inputArr);
    }
    currentSum = currentSum.div((float) outputs.size());
    DefaultData newData = PredictorUtils.updateData(outputs.get(0).getData(), currentSum);
    respBuilder.setData(newData);
    respBuilder.setMeta(outputs.get(0).getMeta());
    respBuilder.setStatus(outputs.get(0).getStatus());
    return respBuilder.build();
}
Also used : APIException(io.seldon.engine.exception.APIException) SeldonMessage(io.seldon.protos.PredictionProtos.SeldonMessage) INDArray(org.nd4j.linalg.api.ndarray.INDArray) DefaultData(io.seldon.protos.PredictionProtos.DefaultData)
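The averaging step in aggregate can be illustrated without the Seldon or ND4J dependencies. The following is only a minimal sketch on plain double[][] arrays: the class AverageCombinerSketch and its average method are hypothetical names introduced here, and IllegalArgumentException stands in for the project's APIException.

```java
import java.util.Arrays;
import java.util.List;

public class AverageCombinerSketch {

    // Hypothetical helper: averages equally shaped 2-D matrices element-wise,
    // mirroring the shape checks of aggregate() on plain arrays.
    public static double[][] average(List<double[][]> outputs) {
        if (outputs.isEmpty()) {
            throw new IllegalArgumentException("Combiner received no inputs");
        }
        int rows = outputs.get(0).length;
        int cols = rows == 0 ? 0 : outputs.get(0)[0].length;
        double[][] sum = new double[rows][cols];
        for (double[][] m : outputs) {
            if (m.length != rows) {
                throw new IllegalArgumentException(
                        String.format("Expected batch length %d but found %d", rows, m.length));
            }
            for (int r = 0; r < rows; r++) {
                if (m[r].length != cols) {
                    throw new IllegalArgumentException(
                            String.format("Expected feature length %d but found %d", cols, m[r].length));
                }
                for (int c = 0; c < cols; c++) {
                    sum[r][c] += m[r][c];
                }
            }
        }
        // Divide the accumulated sum by the number of inputs
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                sum[r][c] /= outputs.size();
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        double[][] a = {{1.0, 2.0}, {3.0, 4.0}};
        double[][] b = {{3.0, 4.0}, {5.0, 6.0}};
        // prints [[2.0, 3.0], [4.0, 5.0]]
        System.out.println(Arrays.deepToString(average(Arrays.asList(a, b))));
    }
}
```

Unlike aggregate, this sketch returns only the averaged values; the original also copies the meta and status fields from the first output into the response.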

Example 27 with SeldonMessage

use of io.seldon.protos.PredictionProtos.SeldonMessage in project seldon-core by SeldonIO.

the class PredictiveUnitBean method getOutput.

public SeldonMessage getOutput(SeldonMessage request, PredictiveUnitState state) throws InterruptedException, ExecutionException, InvalidProtocolBufferException {
    Map<String, Integer> routingDict = new HashMap<String, Integer>();
    SeldonMessage response = getOutputAsync(request, state, routingDict).get();
    SeldonMessage.Builder builder = SeldonMessage.newBuilder(response).setMeta(Meta.newBuilder(response.getMeta()).putAllRouting(routingDict));
    return builder.build();
}
Also used : SeldonMessage(io.seldon.protos.PredictionProtos.SeldonMessage) HashMap(java.util.HashMap)

Example 28 with SeldonMessage

use of io.seldon.protos.PredictionProtos.SeldonMessage in project seldon-core by SeldonIO.

the class PredictiveUnitBean method getOutputAsync.

@Async
private Future<SeldonMessage> getOutputAsync(SeldonMessage input, PredictiveUnitState state, Map<String, Integer> routingDict) throws InterruptedException, ExecutionException, InvalidProtocolBufferException {
    // Get the actual implementation (microservice or hardcoded)
    PredictiveUnitImpl implementation = predictorConfig.getImplementation(state);
    if (implementation == null) {
        implementation = this;
    }
    // Compute the transformed Input
    SeldonMessage transformedInput = implementation.transformInput(input, state);
    // Preserve the original metadata
    transformedInput = mergeMeta(transformedInput, input.getMeta());
    if (state.children.isEmpty()) {
        // If this unit has no children, the transformed input becomes the output
        return new AsyncResult<>(transformedInput);
    }
    List<PredictiveUnitState> selectedChildren = new ArrayList<PredictiveUnitState>();
    List<Future<SeldonMessage>> deferredChildrenOutputs = new ArrayList<Future<SeldonMessage>>();
    List<SeldonMessage> childrenOutputs = new ArrayList<SeldonMessage>();
    // Get the routing. -1 means all children
    int routing = implementation.route(transformedInput, state);
    sanityCheckRouting(routing, state);
    // Update the routing dictionary
    routingDict.put(state.name, routing);
    if (routing == -1) {
        // No routing, propagate to all children
        selectedChildren = state.children;
    } else {
        // Propagate to selected child only
        selectedChildren.add(state.children.get(routing));
    }
    // Get all the children outputs asynchronously
    for (PredictiveUnitState childState : selectedChildren) {
        deferredChildrenOutputs.add(getOutputAsync(transformedInput, childState, routingDict));
    }
    for (Future<SeldonMessage> deferredOutput : deferredChildrenOutputs) {
        childrenOutputs.add(deferredOutput.get());
    }
    // Compute the backward transformation of all children outputs
    SeldonMessage aggregatedOutput = implementation.aggregate(childrenOutputs, state);
    // Merge all the outputs metadata
    aggregatedOutput = mergeMeta(aggregatedOutput, childrenOutputs);
    SeldonMessage transformedOutput = implementation.transformOutput(aggregatedOutput, state);
    // Preserve metadata
    transformedOutput = mergeMeta(transformedOutput, aggregatedOutput.getMeta());
    return new AsyncResult<>(transformedOutput);
}
Also used : SeldonMessage(io.seldon.protos.PredictionProtos.SeldonMessage) ArrayList(java.util.ArrayList) Future(java.util.concurrent.Future) AsyncResult(org.springframework.scheduling.annotation.AsyncResult) Async(org.springframework.scheduling.annotation.Async)
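The routing convention used in getOutputAsync (-1 selects every child, any other value selects a single child by index) can be sketched in isolation. RoutingSketch and selectChildren are hypothetical names introduced for illustration, and the bounds check is an assumption about what sanityCheckRouting might verify, since its body is not shown in the source.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class RoutingSketch {

    // Hypothetical helper: returns the children selected by a routing value.
    // -1 means "all children"; otherwise the value is an index into the list.
    public static <T> List<T> selectChildren(List<T> children, int routing) {
        // Assumed validation; the real sanityCheckRouting is not shown in the source
        if (routing < -1 || routing >= children.size()) {
            throw new IllegalArgumentException(
                    "Invalid routing " + routing + " for " + children.size() + " children");
        }
        if (routing == -1) {
            // Copy so callers cannot mutate the original child list
            return new ArrayList<>(children);
        }
        List<T> selected = new ArrayList<>();
        selected.add(children.get(routing));
        return selected;
    }

    public static void main(String[] args) {
        List<String> children = Arrays.asList("a", "b", "c");
        System.out.println(selectChildren(children, -1)); // prints [a, b, c]
        System.out.println(selectChildren(children, 1));  // prints [b]
    }
}
```

The sketch copies the child list when routing is -1, whereas the original assigns state.children directly; copying avoids accidental mutation of shared state at the cost of one extra allocation.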

Aggregations

SeldonMessage (io.seldon.protos.PredictionProtos.SeldonMessage) 28
Test (org.junit.Test) 18
ByteString (com.google.protobuf.ByteString) 7
ArrayList (java.util.ArrayList) 5
PredictiveUnit (io.seldon.protos.DeploymentProtos.PredictiveUnit) 4
PredictorSpec (io.seldon.protos.DeploymentProtos.PredictorSpec) 4
SpringBootTest (org.springframework.boot.test.context.SpringBootTest) 4
InvalidProtocolBufferException (com.google.protobuf.InvalidProtocolBufferException) 3
APIException (io.seldon.engine.exception.APIException) 3
HashMap (java.util.HashMap) 3
PodTemplateSpec (io.kubernetes.client.proto.V1.PodTemplateSpec) 2
DefaultData (io.seldon.protos.PredictionProtos.DefaultData) 2
HttpHeaders (org.springframework.http.HttpHeaders) 2
ResponseEntity (org.springframework.http.ResponseEntity) 2
RequestMapping (org.springframework.web.bind.annotation.RequestMapping) 2
Value (com.google.protobuf.Value) 1
SeldonAPIException (io.seldon.apife.exception.SeldonAPIException) 1
PredictorState (io.seldon.engine.predictors.PredictorState) 1
Parameter (io.seldon.protos.DeploymentProtos.Parameter) 1
IOException (java.io.IOException) 1