Example 1 with ZeroInitScheme

Use of org.nd4j.weightinit.impl.ZeroInitScheme in project nd4j by deeplearning4j.

From class TFGraphMapper, method mapNodeType.

@Override
public void mapNodeType(NodeDef tfNode, ImportState<GraphDef, NodeDef> importState) {
    // note: isVariableNode must not be part of this early return, or the
    // variable-mapping branch below would be unreachable
    if (shouldSkip(tfNode) || alreadySeen(tfNode)) {
        return;
    }
    val diff = importState.getSameDiff();
    if (isVariableNode(tfNode)) {
        List<Integer> dimensions = new ArrayList<>();
        Map<String, AttrValue> attributes = getAttrMap(tfNode);
        if (attributes.containsKey(VALUE_ATTR_KEY)) {
            diff.var(getName(tfNode), getArrayFrom(tfNode, importState.getGraph()));
        } else if (attributes.containsKey(SHAPE_KEY)) {
            AttrValue shape = attributes.get(SHAPE_KEY);
            int[] shapeArr = getShapeFromAttr(shape);
            int dims = shapeArr.length;
            if (dims > 0) {
                // even vector is 2d in nd4j
                if (dims == 1)
                    dimensions.add(1);
                for (int e = 0; e < dims; e++) {
                    // TODO: eventually we want long shapes :(
                    dimensions.add(shapeArr[e]);
                }
            }
        }
    } else if (isPlaceHolder(tfNode)) {
        val vertexId = diff.getVariable(getName(tfNode));
        diff.addAsPlaceHolder(vertexId.getVarName());
    } else {
        val opName = tfNode.getOp();
        val nodeName = tfNode.getName();
        // FIXME: early draft
        // conditional import
        /*
            if (nodeName.startsWith("cond") && nodeName.contains("/")) {
                val str = nodeName.replaceAll("/.*$","");
                importCondition(str, tfNode, importState);

                seenNodes.add(nodeName);
                return;
            } else if (nodeName.startsWith("while")) {
                // while loop import

                return;
            }
            */
        val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(opName);
        if (differentialFunction == null) {
            throw new ND4JIllegalStateException("No TensorFlow op found for " + opName + "; possibly a missing operation class?");
        }
        try {
            val newInstance = differentialFunction.getClass().newInstance();
            val args = new SDVariable[tfNode.getInputCount()];
            newInstance.setOwnName(tfNode.getName());
            for (int i = 0; i < tfNode.getInputCount(); i++) {
                val name = getNodeName(tfNode.getInput(i));
                args[i] = diff.getVariable(name);
                if (args[i] == null) {
                    args[i] = diff.var(name, null, new ZeroInitScheme('f'));
                    diff.addAsPlaceHolder(args[i].getVarName());
                }
                /*
                 * Associate the output/result variable with its inputs and
                 * notify the variable that it has a placeholder argument to
                 * resolve before anything can execute.
                 */
                if (diff.isPlaceHolder(args[i].getVarName())) {
                    diff.putPlaceHolderForVariable(args[i].getVarName(), name);
                }
            }
            diff.addArgsFor(args, newInstance);
            newInstance.setSameDiff(importState.getSameDiff());
            newInstance.initFromTensorFlow(tfNode, diff, getAttrMap(tfNode), importState.getGraph());
            mapProperties(newInstance, tfNode, importState.getGraph(), importState.getSameDiff(), newInstance.mappingsForFunction());
            importState.getSameDiff().putFunctionForId(newInstance.getOwnName(), newInstance);
            // ensure we can track node name to function instance later.
            diff.setBaseNameForFunctionInstanceId(tfNode.getName(), newInstance);
            diff.addVarNameForImport(tfNode.getName());
        } catch (Exception e) {
            log.error("Failed with [{}]", opName);
            throw new RuntimeException(e);
        }
    }
}
Also used: lombok.val (lombok.val), ZeroInitScheme (org.nd4j.weightinit.impl.ZeroInitScheme), ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)
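
The recurring pattern above is worth isolating: when an op input has not been declared yet, the mapper creates a variable with a ZeroInitScheme (null shape, deferred allocation) and registers it as a placeholder to be resolved before execution. A minimal sketch of that pattern, using a hypothetical variable name "x":

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.weightinit.impl.ZeroInitScheme;

SameDiff sd = SameDiff.create();
// "x" is a hypothetical name; the null shape defers allocation until the
// placeholder is resolved, and 'f' requests Fortran (column-major) order
SDVariable x = sd.var("x", null, new ZeroInitScheme('f'));
sd.addAsPlaceHolder(x.getVarName());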

Example 2 with ZeroInitScheme

Use of org.nd4j.weightinit.impl.ZeroInitScheme in project nd4j by deeplearning4j.

From class SameDiffTests, method testLinearModule.

@Test
public void testLinearModule() {
    int nIn = 5;
    Linear linear = Linear.execBuilder()
            .nIn(nIn).nOut(4)
            .weightInitScheme(new UniformInitScheme('f', nIn))
            .biasWeightInitScheme(new ZeroInitScheme('f'))
            .build();
    linear.exec(Nd4j.linspace(1, 20, 20).reshape(4, 5));
    assertEquals(1, linear.numOutputArguments());
}
Also used: UniformInitScheme (org.nd4j.weightinit.impl.UniformInitScheme), ZeroInitScheme (org.nd4j.weightinit.impl.ZeroInitScheme), Linear (org.nd4j.linalg.api.ops.impl.layers.Linear), Test (org.junit.Test)
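
A note on the bias above: ZeroInitScheme simply produces a zero-filled array in the requested ('f', Fortran) order. A hand-built equivalent, assuming the bias is laid out as a 1 x nOut row vector (the layout is an assumption for illustration, not taken from the Linear source):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// Nd4j.create returns zero-initialized storage, so this matches what
// ZeroInitScheme('f') would yield for a 1 x 4 bias with nOut = 4
INDArray bias = Nd4j.create(new int[] {1, 4}, 'f');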

Example 3 with ZeroInitScheme

Use of org.nd4j.weightinit.impl.ZeroInitScheme in project nd4j by deeplearning4j.

From class While, method doImport.

private void doImport(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph, Set<String> skipSet, AtomicInteger currIndex) {
    val uniqueId = java.util.UUID.randomUUID().toString();
    skipSet.add(nodeDef.getName());
    val scopeCondition = SameDiff.create();
    val scopeLoop = SameDiff.create();
    initWith.putSubFunction("condition-" + uniqueId, scopeCondition);
    initWith.putSubFunction("loopbody-" + uniqueId, scopeLoop);
    this.loopBodyExecution = scopeLoop;
    this.predicateExecution = scopeCondition;
    this.startPosition = currIndex;
    log.info("Adding 2 new scopes for WHILE {}");
    val nodes = graph.getNodeList();
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        if (!tfNode.getOp().equalsIgnoreCase("enter")) {
            // skipSet.add(tfNode.getName());
            break;
        }
        // if (skipSet.contains(tfNode.getName()))
        // continue;
        skipSet.add(tfNode.getName());
        val vars = new SDVariable[tfNode.getInputCount()];
        for (int e = 0; e < tfNode.getInputCount(); e++) {
            val input = TFGraphMapper.getInstance().getNodeName(tfNode.getInput(e));
            vars[e] = initWith.getVariable(input) == null ? initWith.var(input, null, new ZeroInitScheme()) : initWith.getVariable(input);
            scopeCondition.var(vars[e]);
            scopeLoop.var(vars[e]);
        }
        this.inputVars = vars;
    }
    // now we're skipping Merge step, since we've already captured variables at Enter step
    int mergedCnt = 0;
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        if (!tfNode.getOp().equalsIgnoreCase("merge")) {
            scopeLoop.var(TFGraphMapper.getInstance().getNodeName(tfNode.getName()), null, new ZeroInitScheme());
            break;
        }
        skipSet.add(tfNode.getName());
        val var = scopeLoop.var(TFGraphMapper.getInstance().getNodeName(tfNode.getName()), null, new ZeroInitScheme());
        scopeCondition.var(var);
        initWith.var(var);
        mergedCnt++;
    }
    // now, we're adding conditional scope
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        // we're parsing up to condition
        if (tfNode.getOp().equalsIgnoreCase("LoopCond")) {
            skipSet.add(tfNode.getName());
            currIndex.incrementAndGet();
            break;
        }
        boolean isConst = tfNode.getOp().equalsIgnoreCase("const");
        boolean isVar = tfNode.getOp().startsWith("VariableV");
        boolean isPlaceholder = tfNode.getOp().startsWith("Placeholder");
        if (isConst || isVar || isPlaceholder) {
            val var = scopeCondition.var(tfNode.getName(), null, new ZeroInitScheme());
            scopeLoop.var(var);
            initWith.var(var);
            log.info("Adding condition var [{}]", var.getVarName());
        } else if (!skipSet.contains(tfNode.getName())) {
            val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName());
            func.initFromTensorFlow(tfNode, scopeCondition, nodeDef.getAttrMap(), graph);
            func.setSameDiff(scopeLoop);
        }
        skipSet.add(tfNode.getName());
    }
    // time to skip some Switch calls
    int switchCnt = 0;
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        // we're parsing up to condition
        if (!tfNode.getOp().equalsIgnoreCase("Switch"))
            break;
        switchCnt++;
        skipSet.add(tfNode.getName());
    }
    // now we're parsing Identity step
    int identityCnt = 0;
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        if (!tfNode.getOp().equalsIgnoreCase("Identity")) {
            break;
        }
        val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName());
        func.initFromTensorFlow(tfNode, initWith, nodeDef.getAttrMap(), graph);
        func.setSameDiff(scopeLoop);
        val variables = new SDVariable[tfNode.getInputCount()];
        for (int i = 0; i < tfNode.getInputCount(); i++) {
            val testVar = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i)));
            // reuse the variable if it already exists; otherwise declare a zero-initialized placeholder
            variables[i] = testVar != null ? testVar : initWith.var(tfNode.getInput(i), null, new ZeroInitScheme());
            scopeCondition.var(variables[i]);
            scopeLoop.var(variables[i]);
        }
        scopeLoop.addArgsFor(variables, func);
        skipSet.add(tfNode.getName());
    }
    // parsing body scope
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        if (skipSet.contains(tfNode.getName())) {
            log.info("Skipping: {}", tfNode.getName());
            continue;
        }
        if (tfNode.getOp().equalsIgnoreCase("NextIteration")) {
            break;
        }
        boolean isConst = tfNode.getOp().equalsIgnoreCase("const");
        boolean isVar = tfNode.getOp().startsWith("VariableV");
        boolean isPlaceholder = tfNode.getOp().startsWith("Placeholder");
        if (isConst || isVar || isPlaceholder) {
            val var = scopeLoop.var(tfNode.getName(), null, new ZeroInitScheme());
            log.info("Adding body var [{}]", var.getVarName());
        } else {
            log.info("starting on [{}]: {}", tfNode.getName(), tfNode.getOp());
            if (tfNode.getOp().equalsIgnoreCase("enter")) {
                log.info("NEW LOOP ----------------------------------------");
                val func = new While(currIndex);
                func.doImport(nodeDef, initWith, attributesForNode, graph, skipSet, currIndex);
                func.setSameDiff(initWith);
                log.info("END LOOP ----------------------------------------");
            } else {
                val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName());
                func.initFromTensorFlow(tfNode, initWith, nodeDef.getAttrMap(), graph);
                func.setSameDiff(scopeCondition);
                val variables = new SDVariable[tfNode.getInputCount()];
                for (int i = 0; i < tfNode.getInputCount(); i++) {
                    val name = TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i));
                    variables[i] = scopeCondition.getVariable(name);
                    if (variables[i] == null) {
                        // reuse the loop-scope variable if present; otherwise register
                        // the outer variable in the condition scope (the trailing
                        // Nd4j.scalar(1.0) fallback could never be reached)
                        if (scopeLoop.getVariable(name) != null)
                            variables[i] = scopeLoop.getVariable(name);
                        else
                            variables[i] = scopeCondition.var(initWith.getVariable(name));
                    }
                }
                scopeLoop.addArgsFor(variables, func);
            }
        }
        skipSet.add(tfNode.getName());
    }
    val returnInputs = new ArrayList<SDVariable>();
    val returnOutputs = new ArrayList<SDVariable>();
    // map NextIteration nodes to the Return op inputs
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        if (!tfNode.getOp().equalsIgnoreCase("NextIteration"))
            break;
        skipSet.add(tfNode.getName());
        val inputName = TFGraphMapper.getInstance().getNodeName(tfNode.getName());
        val input = initWith.getVariable(inputName) == null ? initWith.var(inputName, null, new ZeroInitScheme()) : initWith.getVariable(inputName);
        returnInputs.add(input);
    }
    // note: returnOutputs is never populated above, so outputVars is empty in this draft
    this.outputVars = returnOutputs.toArray(new SDVariable[returnOutputs.size()]);
    this.inputVars = returnInputs.toArray(new SDVariable[returnInputs.size()]);
    initWith.addArgsFor(inputVars, this);
    initWith.addOutgoingFor(outputVars, this);
    // we should also map While/Exit to libnd4j while
    int exitCnt = 0;
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        if (!tfNode.getOp().equalsIgnoreCase("Exit")) {
            // skipSet.add(tfNode.getName());
            break;
        }
        skipSet.add(tfNode.getName());
        val inputName = TFGraphMapper.getInstance().getNodeName(tfNode.getName());
        // resolved here but not yet wired to anything in this early draft
        val input = initWith.getVariable(inputName) == null ? initWith.var(inputName, null, new ZeroInitScheme()) : initWith.getVariable(inputName);
    }
    // the output of the condition should always be a singular scalar
    // this is a safe assumption
    val conditionVars = scopeCondition.functions();
    if (conditionVars.length < 1) {
        throw new ND4JIllegalArgumentException("No functions found!");
    }
    this.targetBoolean = conditionVars[conditionVars.length - 1].outputVariables()[0];
    log.info("-------------------------------------------");
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), ZeroInitScheme (org.nd4j.weightinit.impl.ZeroInitScheme), ND4JIllegalArgumentException (org.nd4j.linalg.exception.ND4JIllegalArgumentException)
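
The get-or-create ternary on the NextIteration and Exit inputs appears several times in doImport. A small helper, hypothetical and not part of the While class, makes the intent explicit:

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.weightinit.impl.ZeroInitScheme;

// hypothetical helper: reuse an existing variable, or declare a
// zero-initialized one with a deferred (null) shape
static SDVariable getOrCreate(SameDiff sd, String name) {
    SDVariable existing = sd.getVariable(name);
    return existing != null ? existing : sd.var(name, null, new ZeroInitScheme());
}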

Example 4 with ZeroInitScheme

Use of org.nd4j.weightinit.impl.ZeroInitScheme in project nd4j by deeplearning4j.

From class While, method init.

private void init(String blockName, SameDiff parent, SDVariable[] inputVars, SameDiff.SameDiffConditional predicate, SameDiff.SameDiffFunctionDefinition condition, SameDiff.SameDiffFunctionDefinition trueBody) {
    this.sameDiff = parent;
    this.inputVars = inputVars;
    this.predicate = predicate;
    this.trueBody = trueBody;
    this.blockName = blockName;
    this.dummyResult = parent.var("dummyresult-" + UUID.randomUUID().toString(), new int[] { 1, 1 }, new ZeroInitScheme('f'));
    parent.putFunctionForId(getOwnName(), this);
    parent.addArgsFor(inputVars, this);
    parent.addOutgoingFor(new SDVariable[] { dummyResult }, this);
    // create a samediff sub graph for running just the execution
    // return a reference to the loop for referencing during actual execution
    SameDiff sameDiff = SameDiff.create();
    // store the reference to the result array and the same diff execution instance
    this.targetBoolean = predicate.eval(sameDiff, condition, inputVars);
    this.predicateExecution = sameDiff;
    // store references to the loop body
    String trueBodyName = "true-body-" + UUID.randomUUID().toString();
    this.trueBodyName = trueBodyName;
    // running define function will setup a proper same diff instance
    parent.defineFunction(trueBodyName, trueBody, inputVars);
    parent.defineFunction(blockName, condition, inputVars);
    parent.putSubFunction("predicate-eval-body", sameDiff);
    // get a reference to the actual loop body
    this.loopBodyExecution = parent.getFunction(trueBodyName);
}
Also used: ZeroInitScheme (org.nd4j.weightinit.impl.ZeroInitScheme), SameDiff (org.nd4j.autodiff.samediff.SameDiff)
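
The dummyResult above is just a 1 x 1 zero-initialized variable that gives the While op an output vertex on the graph. Declared in isolation (the name is illustrative, following the source):

import java.util.UUID;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.weightinit.impl.ZeroInitScheme;

SameDiff sd = SameDiff.create();
// a 1 x 1 zero variable in Fortran order; the UUID suffix only avoids name clashes
SDVariable dummy = sd.var("dummyresult-" + UUID.randomUUID(), new int[] {1, 1}, new ZeroInitScheme('f'));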

Example 5 with ZeroInitScheme

Use of org.nd4j.weightinit.impl.ZeroInitScheme in project nd4j by deeplearning4j.

From class SameDiffTests, method testLinearModule2.

@Test
public void testLinearModule2() {
    Linear linear = Linear.execBuilder()
            .nIn(3).nOut(2)
            .weightInitScheme(new OneInitScheme('f'))
            .biasWeightInitScheme(new ZeroInitScheme('f'))
            .build();
    linear.exec(Nd4j.linspace(1, 6, 6).reshape(2, 3));
    INDArray assertion = Nd4j.create(new double[][] { { 6, 6 }, { 15, 15 } });
    assertEquals(assertion, linear.outputArguments()[0]);
}
Also used: OneInitScheme (org.nd4j.weightinit.impl.OneInitScheme), ZeroInitScheme (org.nd4j.weightinit.impl.ZeroInitScheme), INDArray (org.nd4j.linalg.api.ndarray.INDArray), Linear (org.nd4j.linalg.api.ops.impl.layers.Linear), Test (org.junit.Test)
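
The expected values follow directly: with all-ones weights (OneInitScheme) and a zero bias (ZeroInitScheme), each output entry is the sum of its input row, 1 + 2 + 3 = 6 and 4 + 5 + 6 = 15. The same numbers by plain matrix multiply:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

INDArray x = Nd4j.linspace(1, 6, 6).reshape(2, 3); // rows [1,2,3] and [4,5,6]
INDArray w = Nd4j.ones(3, 2);                      // what OneInitScheme produces
INDArray expected = x.mmul(w);                     // [[6, 6], [15, 15]]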

Aggregations

ZeroInitScheme (org.nd4j.weightinit.impl.ZeroInitScheme): 7
lombok.val (lombok.val): 2
Test (org.junit.Test): 2
SameDiff (org.nd4j.autodiff.samediff.SameDiff): 2
Linear (org.nd4j.linalg.api.ops.impl.layers.Linear): 2
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 2
HashMap (java.util.HashMap): 1
LinkedHashMap (java.util.LinkedHashMap): 1
Map (java.util.Map): 1
SDVariable (org.nd4j.autodiff.samediff.SDVariable): 1
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1
ND4JIllegalArgumentException (org.nd4j.linalg.exception.ND4JIllegalArgumentException): 1
OneInitScheme (org.nd4j.weightinit.impl.OneInitScheme): 1
UniformInitScheme (org.nd4j.weightinit.impl.UniformInitScheme): 1