Search in sources :

Example 1 with FlatResult

Use of org.nd4j.graph.FlatResult in the nd4j project by deeplearning4j.

From the method executeGraph of the class NativeGraphExecutioner.

/**
 * Executes the given graph with the native executor and returns its results.
 *
 * <p>The graph is serialized to FlatBuffers and passed to the native graph executor. Each variable
 * in the returned {@code FlatResult} is converted to an {@link INDArray}. The array is then
 * associated with the same-named variable of {@code sd}.
 *
 * @param sd            graph to execute; each result array is associated with its variable here
 * @param configuration executor configuration passed to the FlatBuffers serializer
 * @return the result arrays, in the order in which the native executor returned them
 * @throws ND4JIllegalStateException if native execution returns {@code null}, or if a returned
 *                                   variable has no name or a name that {@code sd} does not know
 */
@Override
public INDArray[] executeGraph(SameDiff sd, ExecutorConfiguration configuration) {
    Map<Integer, Node> intermediate = new HashMap<>();
    ByteBuffer buffer = convertToFlatBuffers(sd, configuration, intermediate);
    BytePointer bPtr = new BytePointer(buffer);
    log.debug("Buffer length: {}", buffer.limit());

    Pointer res = NativeOpsHolder.getInstance().getDeviceNativeOps().executeFlatGraphFloat(null, bPtr);
    if (res == null) {
        throw new ND4JIllegalStateException("Graph execution failed");
    }

    // FIXME: this is BAD -- the result is read through a fixed 1 MiB window because the size of the
    // serialized result is not available at this point.
    // NOTE(review): a larger result would be read past this view; confirm whether the native API can
    // report the real size.
    PagedPointer pagedPointer = new PagedPointer(res, 1024 * 1024L);
    FlatResult fr = FlatResult.getRootAsFlatResult(pagedPointer.asBytePointer().asByteBuffer());

    // Dumping the whole variable map is expensive; only build the message when it will be emitted.
    if (log.isDebugEnabled()) {
        log.debug("VarMap: {}", sd.variableMap());
    }

    int numResults = fr.variablesLength();
    INDArray[] results = new INDArray[numResults];
    for (int e = 0; e < numResults; e++) {
        FlatVariable flatVar = fr.variables(e);
        String name = flatVar.name();

        if (log.isDebugEnabled()) {
            // id() is an optional FlatBuffers field; do not let a missing id crash execution via logging.
            if (flatVar.id() != null) {
                log.debug("Var received: id: [{}:{}/<{}>];", flatVar.id().first(), flatVar.id().second(), name);
            } else {
                log.debug("Var received: id: [?/<{}>];", name);
            }
        }

        INDArray val = Nd4j.createFromFlatArray(flatVar.ndarray());
        results[e] = val;

        // Every returned variable must correspond to a variable of sd. A missing or unknown name
        // means the native result does not match the serialized graph.
        if (name == null || !sd.variableMap().containsKey(name)) {
            throw new ND4JIllegalStateException("Unknown variable received as result: [" + name + "]");
        }
        sd.associateArrayWithVariable(val, sd.variableMap().get(name));
    }
    return results;
}
Also used: HashMap(java.util.HashMap) BytePointer(org.bytedeco.javacpp.BytePointer) PagedPointer(org.nd4j.linalg.api.memory.pointers.PagedPointer) Pointer(org.bytedeco.javacpp.Pointer) ByteBuffer(java.nio.ByteBuffer) FlatArray(org.nd4j.graph.FlatArray) FlatVariable(org.nd4j.graph.FlatVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) FlatResult(org.nd4j.graph.FlatResult)

Aggregations

ByteBuffer (java.nio.ByteBuffer): 1, HashMap (java.util.HashMap): 1, BytePointer (org.bytedeco.javacpp.BytePointer): 1, Pointer (org.bytedeco.javacpp.Pointer): 1, FlatArray (org.nd4j.graph.FlatArray): 1, FlatResult (org.nd4j.graph.FlatResult): 1, FlatVariable (org.nd4j.graph.FlatVariable): 1, PagedPointer (org.nd4j.linalg.api.memory.pointers.PagedPointer): 1, INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1, ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 1