Search in sources :

Example 6 with ZeroInitScheme

use of org.nd4j.weightinit.impl.ZeroInitScheme in project nd4j by deeplearning4j.

Source: the class SameDiff, method generateOutputVariableForOp.

/**
 * Generate the output variables for the given op and return them.
 * <p>
 * Variable names are derived from {@code baseName} (or, if that is null/empty,
 * from the function's registered base name, then its op name). Multi-output ops
 * follow the convention {@code name:index}, e.g. split, split:1, split:2, ...
 *
 * @param function the function to generate the output variables for
 * @param baseName base name for the output variables; may be null or empty,
 *                 in which case a name is derived from {@code function}
 * @return the variables generated for each output of the function
 * @throws ND4JIllegalStateException if the number of outputs or the output
 *                                   shapes cannot be determined
 */
public SDVariable[] generateOutputVariableForOp(DifferentialFunction function, String baseName) {
    // If no usable base name was supplied, prefer the name already registered
    // for this function, then fall back to its op name.
    // NOTE: parentheses added — the original relied on && binding tighter than ||,
    // and an empty (non-null) baseName was never replaced by the op name.
    if ((baseName == null || baseName.isEmpty()) && getBaseNameForFunction(function) != null)
        baseName = getBaseNameForFunction(function);
    if (baseName == null || baseName.isEmpty())
        baseName = function.opName();
    val outputShape = function.calculateOutputShape();
    if (outputShape == null || outputShape.isEmpty()) {
        if (function instanceof CustomOp) {
            CustomOp customOp = (CustomOp) function;
            val descriptor = customOp.getDescriptor();
            // Without a descriptor the (variable) number of outputs can't be guessed.
            if (descriptor == null || descriptor.getNumOutputs() <= 0) {
                throw new ND4JIllegalStateException("No output variables found!");
            }
            // Use the ordering of the first input array when one is available.
            char ordering = 'c';
            if (function.args()[0].getArr() != null) {
                ordering = function.args()[0].getArr().ordering();
            }
            SDVariable[] ret = new SDVariable[descriptor.getNumOutputs()];
            // Shapes are dynamic here: variables are created with a null shape.
            for (int i = 0; i < ret.length; i++) {
                SDVariable checkGet = getVariable(baseName);
                if (checkGet == null) {
                    checkGet = var(generateNewVarName(baseName, i), null, new ZeroInitScheme(ordering));
                } else if (i > 0 && !importedVarName.contains(baseName)) {
                    // Base name is taken by an earlier output: look up the indexed name.
                    checkGet = getVariable(generateNewVarName(baseName, i));
                }
                if (checkGet == null) {
                    // Indexed name doesn't exist yet either: create it.
                    checkGet = var(generateNewVarName(baseName, i), null, new ZeroInitScheme(ordering));
                }
                ret[i] = checkGet;
            }
            return ret;
        } else if (function instanceof BaseOp) {
            // Unresolved shape for a legacy xyz op: these always have exactly one
            // output. (The original additionally tested outputShape.isEmpty() here,
            // which threw an NPE whenever calculateOutputShape() returned null.)
            SDVariable[] ret = new SDVariable[1];
            SDVariable checkGet = getVariable(baseName);
            char ordering = 'c';
            if (function.args()[0].getArr() != null) {
                ordering = function.args()[0].getArr().ordering();
            }
            if (checkGet == null) {
                checkGet = var(baseName, null, new ZeroInitScheme(ordering));
            } else if (!importedVarName.contains(baseName)) {
                // The name is taken and wasn't imported: generate an indexed one.
                checkGet = var(generateNewVarName(baseName, 0), null, new ZeroInitScheme(ordering));
            }
            ret[0] = checkGet;
            return ret;
        }
        // Neither CustomOp nor BaseOp: fail explicitly rather than NPE on
        // outputShape.size() below.
        throw new ND4JIllegalStateException("Unable to determine output shapes for op " + function.opName());
    }
    // Shapes are known: create (or update the shape of) one variable per output.
    char ordering = 'c';
    if (function.args()[0].getArr() != null) {
        ordering = function.args()[0].getArr().ordering();
    }
    SDVariable[] ret = new SDVariable[outputShape.size()];
    val rootName = baseName;
    for (int i = 0; i < ret.length; i++) {
        val shape = outputShape.get(i);
        // Naming convention: rootName:index, i.e. split:1, split:2, split:3, ...
        baseName = rootName + (i > 0 ? ":" + i : "");
        SDVariable checkGet = getVariable(baseName);
        if (checkGet == null) {
            // No such variable yet — just create it.
            checkGet = var(baseName, shape, new ZeroInitScheme(ordering));
        } else if (shape != null && !shapeAlreadyExistsForVarName(checkGet.getVarName())) {
            // Variable exists but has no recorded shape — record it now.
            putShapeForVarName(checkGet.getVarName(), shape);
        } else if (shape != null && shapeAlreadyExistsForVarName(checkGet.getVarName())) {
            // Variable and shape both already exist — nothing to do.
            // TODO: consider validating shape equality here.
        } else if (!importedVarName.contains(baseName)) {
            // Reachable only when shape == null (the original flagged this as a
            // dead end): probe for an unused "_count" suffixed name.
            int count = 1;
            String name = baseName + "_" + count + (i > 0 ? ":" + i : "");
            while (getVariable(name) != null) {
                count++;
                name = baseName + "_" + count + (i > 0 ? ":" + i : "");
            }
            // (Removed an unreachable "converged on already generated variable"
            // throw: the loop above exits only when getVariable(name) == null.)
            checkGet = var(name, shape, new ZeroInitScheme(ordering));
        }
        if (checkGet == null) {
            checkGet = var(baseName + (i > 0 ? ":" + i : ""), shape, new ZeroInitScheme(ordering));
        }
        ret[i] = checkGet;
    }
    return ret;
}
Also used : ZeroInitScheme(org.nd4j.weightinit.impl.ZeroInitScheme) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Example 7 with ZeroInitScheme

use of org.nd4j.weightinit.impl.ZeroInitScheme in project nd4j by deeplearning4j.

Source: the class BaseGraphMapper, method importGraph.

/**
 * Import the given graph into a new {@link SameDiff} instance.
 * <p>
 * Every tensor reported by {@code variablesForGraph} becomes a SameDiff
 * variable (placeholders are registered as such), then every non-ignored
 * node is mapped via {@code mapNodeType}.
 *
 * @param tfGraph the graph to import
 * @return the populated {@code SameDiff} instance
 */
@Override
public SameDiff importGraph(GRAPH_TYPE tfGraph) {
    SameDiff diff = SameDiff.create();
    ImportState<GRAPH_TYPE, TENSOR_TYPE> importState = new ImportState<>();
    importState.setSameDiff(diff);
    importState.setGraph(tfGraph);
    val variablesForGraph = variablesForGraph(tfGraph);
    importState.setVariables(variablesForGraph);
    // Create one SameDiff variable per tensor in the graph.
    for (Map.Entry<String, TENSOR_TYPE> entry : variablesForGraph.entrySet()) {
        if (dataTypeForTensor(entry.getValue()) == DataBuffer.Type.UNKNOWN) {
            // Unknown data type: create a shapeless, zero-initialized variable.
            val var = importState.getSameDiff().var(entry.getKey(), null, new ZeroInitScheme('c'));
            // Register placeholders so their resolution can be validated later.
            if (isPlaceHolder(entry.getValue())) {
                importState.getSameDiff().addAsPlaceHolder(var.getVarName());
                if (var.getShape() != null)
                    importState.getSameDiff().setOriginalPlaceHolderShape(var.getVarName(), var.getShape());
            }
            continue;
        }
        val arr = getNDArrayFromTensor(entry.getKey(), entry.getValue(), tfGraph);
        if (arr != null) {
            // Tensor has concrete data: create the variable from the array.
            val var = importState.getSameDiff().var(entry.getKey(), arr);
            // Ensure the array is made available for later processing.
            diff.associateArrayWithVariable(arr, var);
        } else if (getShapeFromTensor(entry.getValue()) == null) {
            // No data and no shape: create a shapeless, zero-initialized variable.
            val var = importState.getSameDiff().var(entry.getKey(), null, new ZeroInitScheme('c'));
            // Register as a placeholder if applicable.
            if (isPlaceHolder(entry.getValue())) {
                // NOTE(review): getShapeFromTensor(...) returned null in the branch
                // condition above, so originalShape is always null here — possibly
                // var.getShape() was intended instead; confirm against callers.
                val originalShape = getShapeFromTensor(entry.getValue());
                importState.getSameDiff().addAsPlaceHolder(var.getVarName());
                if (var.getShape() != null)
                    importState.getSameDiff().setOriginalPlaceHolderShape(var.getVarName(), originalShape);
            }
        } else {
            // No data but a known shape: create the variable from the shape.
            val originalShape = getShapeFromTensor(entry.getValue());
            val var = importState.getSameDiff().var(entry.getKey(), originalShape);
            // Register as a placeholder if applicable.
            if (isPlaceHolder(entry.getValue())) {
                importState.getSameDiff().addAsPlaceHolder(var.getVarName());
                importState.getSameDiff().setOriginalPlaceHolderShape(var.getVarName(), originalShape);
            }
        }
    }
    // Map each node to SameDiff ops/variables, skipping ignored op types
    // (unless the node is an explicit exception to the ignore list).
    val tfNodesList = getNodeList(tfGraph);
    for (NODE_TYPE tfNode : tfNodesList) {
        if (!opsToIgnore().contains(getOpType(tfNode)) || isOpIgnoreException(tfNode))
            mapNodeType(tfNode, importState);
    }
    return diff;
}
Also used : lombok.val(lombok.val) ZeroInitScheme(org.nd4j.weightinit.impl.ZeroInitScheme) SameDiff(org.nd4j.autodiff.samediff.SameDiff) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) Map(java.util.Map)

Aggregations

ZeroInitScheme (org.nd4j.weightinit.impl.ZeroInitScheme)7 lombok.val (lombok.val)2 Test (org.junit.Test)2 SameDiff (org.nd4j.autodiff.samediff.SameDiff)2 Linear (org.nd4j.linalg.api.ops.impl.layers.Linear)2 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)2 HashMap (java.util.HashMap)1 LinkedHashMap (java.util.LinkedHashMap)1 Map (java.util.Map)1 SDVariable (org.nd4j.autodiff.samediff.SDVariable)1 INDArray (org.nd4j.linalg.api.ndarray.INDArray)1 ND4JIllegalArgumentException (org.nd4j.linalg.exception.ND4JIllegalArgumentException)1 OneInitScheme (org.nd4j.weightinit.impl.OneInitScheme)1 UniformInitScheme (org.nd4j.weightinit.impl.UniformInitScheme)1