Usage of `org.nd4j.graph.FlatArray` in the nd4j project (deeplearning4j):
class `NativeGraphExecutioner`, method `executeGraph`.
/**
 * Executes the given graph through the native FlatBuffers-based executor and returns the
 * arrays it produced.
 *
 * <p>The graph is serialized with {@code convertToFlatBuffers}, passed to
 * {@code executeFlatGraphFloat}, and each variable of the returned {@code FlatResult} is
 * converted with {@link Nd4j#createFromFlatArray}. Every converted array is also associated
 * with the variable of the same name in {@code sd}.
 *
 * @param sd            the graph to execute
 * @param configuration executor configuration forwarded to the FlatBuffers serialization
 * @return one array per variable in the native result, in the order the result lists them
 * @throws ND4JIllegalStateException if native execution returns no result, or if a result
 *                                   variable has no name or no matching variable in {@code sd}
 */
@Override
public INDArray[] executeGraph(SameDiff sd, ExecutorConfiguration configuration) {
    Map<Integer, Node> intermediate = new HashMap<>();
    ByteBuffer buffer = convertToFlatBuffers(sd, configuration, intermediate);
    BytePointer bPtr = new BytePointer(buffer);
    log.debug("Buffer length: {}", buffer.limit());

    Pointer res = NativeOpsHolder.getInstance().getDeviceNativeOps().executeFlatGraphFloat(null, bPtr);
    if (res == null) {
        throw new ND4JIllegalStateException("Graph execution failed");
    }
    // NOTE(review): `res` is never released in this method; confirm whether NativeOps
    // exposes a call to free the native result and invoke it once parsing is done.

    // FIXME: the size of the native result is not known here, so a fixed 1 MiB window is
    // mapped over it; presumably a larger result cannot be read correctly — TODO confirm.
    PagedPointer pagedPointer = new PagedPointer(res, 1024 * 1024L);
    FlatResult fr = FlatResult.getRootAsFlatResult(pagedPointer.asBytePointer().asByteBuffer());

    log.debug("VarMap: {}", sd.variableMap());
    int numVariables = fr.variablesLength();
    INDArray[] results = new INDArray[numVariables];
    for (int e = 0; e < numVariables; e++) {
        FlatVariable flatVar = fr.variables(e);
        String name = flatVar.name();
        log.debug("Var received: id: [{}:{}/<{}>];", flatVar.id().first(), flatVar.id().second(), name);

        FlatArray ndarray = flatVar.ndarray();
        INDArray val = Nd4j.createFromFlatArray(ndarray);
        results[e] = val;

        // Every result variable must map back onto a named variable of this graph;
        // anything else means native and Java views of the graph disagree.
        if (name == null || sd.variableMap().get(name) == null) {
            throw new ND4JIllegalStateException("Unknown variable received as result: [" + name + "]");
        }
        sd.associateArrayWithVariable(val, sd.variableMap().get(name));
    }
    return results;
}
Aggregations