Usage of io.seldon.protos.PredictionProtos.SeldonMessage in project seldon-core by SeldonIO.
Class AverageCombinerUnit, method aggregate.
/**
 * Combines the outputs of all child units by taking their element-wise mean.
 *
 * <p>Every output must carry 2-dimensional data with the same shape as the first output. The
 * averaged values are written back via {@code PredictorUtils.updateData} using the first output's
 * data as the template, and the response copies the first output's meta and status.
 *
 * @param outputs the child outputs to combine; must be non-null and non-empty
 * @param state the state of this predictive unit (not used by this combiner)
 * @return a message whose data is the element-wise average of all outputs
 * @throws APIException of type {@code ENGINE_INVALID_COMBINER_RESPONSE} if there are no outputs,
 *     a shape cannot be extracted, any data is not 2-dimensional, or the shapes disagree
 */
@Override
public SeldonMessage aggregate(List<SeldonMessage> outputs, PredictiveUnitState state) {
    if (outputs == null || outputs.isEmpty()) {
        throw new APIException(APIException.ApiExceptionType.ENGINE_INVALID_COMBINER_RESPONSE, "Combiner received no inputs");
    }
    SeldonMessage first = outputs.get(0);
    // The first output fixes the shape every other output must match.
    int[] shape = requireMatrixShape(first.getData());
    INDArray currentSum = Nd4j.zeros(shape[0], shape[1]);
    for (SeldonMessage output : outputs) {
        DefaultData inputData = output.getData();
        int[] inputShape = requireMatrixShape(inputData);
        if (inputShape[0] != shape[0]) {
            throw new APIException(APIException.ApiExceptionType.ENGINE_INVALID_COMBINER_RESPONSE, String.format("Expected batch length %d but found %d", shape[0], inputShape[0]));
        }
        if (inputShape[1] != shape[1]) {
            // Second dimension, not the batch dimension: report it as such.
            throw new APIException(APIException.ApiExceptionType.ENGINE_INVALID_COMBINER_RESPONSE, String.format("Expected feature length %d but found %d", shape[1], inputShape[1]));
        }
        currentSum = currentSum.add(PredictorUtils.getINDArray(inputData));
    }
    // Turn the running sum into a mean.
    currentSum = currentSum.div((float) outputs.size());
    DefaultData newData = PredictorUtils.updateData(first.getData(), currentSum);
    return SeldonMessage.newBuilder()
            .setData(newData)
            .setMeta(first.getMeta())
            .setStatus(first.getStatus())
            .build();
}

/**
 * Returns the shape of {@code data}, rejecting data whose shape cannot be extracted or that is not
 * 2-dimensional.
 *
 * @param data the data whose shape is required
 * @return the 2-element shape array
 * @throws APIException of type {@code ENGINE_INVALID_COMBINER_RESPONSE} on a missing or non-2D shape
 */
private static int[] requireMatrixShape(DefaultData data) {
    int[] shape = PredictorUtils.getShape(data);
    if (shape == null) {
        throw new APIException(APIException.ApiExceptionType.ENGINE_INVALID_COMBINER_RESPONSE, "Combiner cannot extract data shape");
    }
    if (shape.length != 2) {
        throw new APIException(APIException.ApiExceptionType.ENGINE_INVALID_COMBINER_RESPONSE, "Combiner received data that is not 2 dimensional");
    }
    return shape;
}
Usage of io.seldon.protos.PredictionProtos.SeldonMessage in project seldon-core by SeldonIO.
Class PredictiveUnitBean, method getOutput.
/**
 * Evaluates the predictive graph rooted at {@code state} for {@code request} and returns its final
 * output, with the routing decisions taken during evaluation added to the output's meta routing map.
 *
 * @param request the incoming request
 * @param state the root predictive unit of the graph to evaluate
 * @return the graph output with routing metadata attached
 * @throws InterruptedException if waiting for the graph result is interrupted
 * @throws ExecutionException if obtaining the graph result fails
 * @throws InvalidProtocolBufferException as declared by the graph evaluation
 */
public SeldonMessage getOutput(SeldonMessage request, PredictiveUnitState state) throws InterruptedException, ExecutionException, InvalidProtocolBufferException {
    // Filled in by getOutputAsync: unit name -> routing value for every non-leaf unit visited
    // (-1 meaning the request went to all children).
    // NOTE(review): plain HashMap — revisit if children are ever evaluated concurrently.
    Map<String, Integer> routes = new HashMap<>();
    SeldonMessage graphOutput = getOutputAsync(request, state, routes).get();
    Meta metaWithRouting = Meta.newBuilder(graphOutput.getMeta())
            .putAllRouting(routes)
            .build();
    return SeldonMessage.newBuilder(graphOutput)
            .setMeta(metaWithRouting)
            .build();
}
Usage of io.seldon.protos.PredictionProtos.SeldonMessage in project seldon-core by SeldonIO.
Class PredictiveUnitBean, method getOutputAsync.
/**
 * Recursively evaluates the predictive unit {@code state} and its descendants.
 *
 * <p>The unit's implementation transforms the input. This is the one returned by
 * {@code predictorConfig}, or this bean when none is configured. A unit without children returns
 * that transformed input directly.
 *
 * <p>Otherwise the unit routes the transformed input either to one child or, when the route is -1,
 * to all children. The children's outputs are aggregated and their metadata merged. The result then
 * goes through the unit's output transformation.
 *
 * <p>NOTE(review): Spring's proxy-based {@code @Async} is not applied to private methods or to
 * calls made on {@code this}. This method is presumably executed synchronously, evaluating the
 * children one after another; confirm before relying on concurrency here.
 *
 * @param input the message arriving at this unit
 * @param state the unit to evaluate
 * @param routingDict accumulator receiving {@code state.name -> route} for every non-leaf unit visited
 * @return a future (an {@code AsyncResult}) holding this unit's output
 * @throws InterruptedException if waiting for a child result is interrupted
 * @throws ExecutionException if obtaining a child result fails
 * @throws InvalidProtocolBufferException as declared for the unit/child evaluation
 */
@Async
private Future<SeldonMessage> getOutputAsync(SeldonMessage input, PredictiveUnitState state, Map<String, Integer> routingDict) throws InterruptedException, ExecutionException, InvalidProtocolBufferException {
    // Prefer the configured implementation (microservice or hardcoded?); fall back to this bean.
    PredictiveUnitImpl configured = predictorConfig.getImplementation(state);
    PredictiveUnitImpl unit = (configured != null) ? configured : this;

    // Forward transformation, keeping the original request metadata.
    SeldonMessage forwarded = mergeMeta(unit.transformInput(input, state), input.getMeta());

    // Leaf unit: the transformed input is the result.
    if (state.children.isEmpty()) {
        return new AsyncResult<>(forwarded);
    }

    // Decide where the request goes; -1 means every child.
    int route = unit.route(forwarded, state);
    sanityCheckRouting(route, state);
    routingDict.put(state.name, route);

    List<PredictiveUnitState> targets;
    if (route == -1) {
        targets = state.children;
    } else {
        targets = new ArrayList<PredictiveUnitState>();
        targets.add(state.children.get(route));
    }

    // Dispatch to every selected child first, then collect their results in the same order.
    List<Future<SeldonMessage>> pending = new ArrayList<Future<SeldonMessage>>();
    for (PredictiveUnitState child : targets) {
        pending.add(getOutputAsync(forwarded, child, routingDict));
    }
    List<SeldonMessage> childResults = new ArrayList<SeldonMessage>();
    for (Future<SeldonMessage> future : pending) {
        childResults.add(future.get());
    }

    // Backward pass: aggregate, merge the children's metadata, then transform the output
    // while keeping the aggregated metadata.
    SeldonMessage combined = mergeMeta(unit.aggregate(childResults, state), childResults);
    SeldonMessage result = mergeMeta(unit.transformOutput(combined, state), combined.getMeta());
    return new AsyncResult<>(result);
}
Aggregations