
Example 56 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

The class Min, method doDiff:

@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
    SDVariable min = outputVariables()[0];
    // 0/1 masks marking which argument produced the elementwise minimum
    SDVariable eq1 = sameDiff.eq(larg(), min);
    SDVariable eq2 = sameDiff.eq(rarg(), min);
    // route the upstream gradient f1.get(0) through each mask
    return Arrays.asList(eq1.mul(f1.get(0)), eq2.mul(f1.get(0)));
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable)
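For intuition, here is a minimal plain-Java sketch of the masking scheme built above (not nd4j code; MinGradSketch is a hypothetical name): the upstream gradient is routed to whichever argument supplied the elementwise minimum, and to both arguments where they tie.

import java.util.Arrays;

public class MinGradSketch {
    public static void main(String[] args) {
        double[] x = {1.0, 5.0, 2.0};
        double[] y = {3.0, 4.0, 2.0};
        double[] upstream = {0.1, 0.2, 0.3}; // plays the role of f1.get(0)
        double[] gradX = new double[x.length];
        double[] gradY = new double[y.length];
        for (int i = 0; i < x.length; i++) {
            double min = Math.min(x[i], y[i]);
            // eq(larg(), min) and eq(rarg(), min) act as 0/1 masks
            gradX[i] = (x[i] == min ? 1.0 : 0.0) * upstream[i];
            gradY[i] = (y[i] == min ? 1.0 : 0.0) * upstream[i];
        }
        System.out.println(Arrays.toString(gradX)); // [0.1, 0.0, 0.3]
        System.out.println(Arrays.toString(gradY)); // [0.0, 0.2, 0.3] -- the tie at index 2 feeds both
    }
}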

Example 57 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

The class OldDivOp, method doDiff:

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // quotient rule for z = x / y: dL/dx = g / y
    SDVariable gradWrtX = f().div(i_v.get(0), rarg());
    // dL/dy = -(g / y) * (x / y) = -g * x / y^2
    SDVariable gradWrtY = f().mul(f().neg(gradWrtX), f().div(larg(), rarg()));
    List<SDVariable> ret = new ArrayList<>(2);
    ret.add(gradWrtX);
    ret.add(gradWrtY);
    return ret;
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), ArrayList (java.util.ArrayList)
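As a quick numeric check, here is a standalone sketch (plain Java; DivGradSketch is a hypothetical name) of the quotient rule those two lines encode: for z = x / y with upstream gradient g, dL/dx = g / y and dL/dy = -(g / y) * (x / y) = -g * x / y^2.

public class DivGradSketch {
    public static void main(String[] args) {
        double x = 6.0, y = 2.0, g = 1.0; // g plays the role of i_v.get(0)
        double gradWrtX = g / y;               // f().div(i_v.get(0), rarg())
        double gradWrtY = -gradWrtX * (x / y); // f().mul(f().neg(gradWrtX), f().div(larg(), rarg()))
        System.out.println(gradWrtX); // 0.5
        System.out.println(gradWrtY); // -1.5, i.e. -g * x / y^2 = -6 / 4
    }
}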

Example 58 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

The class OldFloorDivOp, method doDiff:

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    SDVariable gradWrtX = f().div(i_v.get(0), rarg());
    SDVariable gradWrtY = f().mul(f().neg(gradWrtX), f().div(larg(), rarg()));
    List<SDVariable> ret = new ArrayList<>(2);
    ret.add(gradWrtX);
    ret.add(gradWrtY);
    return ret;
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), ArrayList (java.util.ArrayList)
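OldFloorDivOp's doDiff is identical to OldDivOp's: it reuses the smooth quotient-rule gradient. Since floor division is piecewise constant, its exact derivative is zero almost everywhere, so the ordinary x / y gradient here presumably serves as a smooth surrogate for backpropagation.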

Example 59 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

The class RectifedLinear, method doDiff:

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // ReLU derivative: the Heaviside step at the cutoff (1 above it, 0 below)
    SDVariable step = new Step(sameDiff, arg(), false, cutoff).outputVariables()[0];
    // multiply the mask into the upstream gradient
    SDVariable ret = step.mul(i_v.get(0));
    return Arrays.asList(ret);
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable)
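Concretely, a plain-Java sketch of this masking (not nd4j code; ReluGradSketch is a hypothetical name, and it assumes the Step op yields 1 strictly above the cutoff):

import java.util.Arrays;

public class ReluGradSketch {
    public static void main(String[] args) {
        double cutoff = 0.0;
        double[] input = {-2.0, 0.5, 3.0};
        double[] upstream = {1.0, 1.0, 1.0}; // plays the role of i_v.get(0)
        double[] grad = new double[input.length];
        for (int i = 0; i < input.length; i++) {
            double step = input[i] > cutoff ? 1.0 : 0.0; // Step(sameDiff, arg(), false, cutoff)
            grad[i] = step * upstream[i];                // step.mul(i_v.get(0))
        }
        System.out.println(Arrays.toString(grad)); // [0.0, 1.0, 1.0]
    }
}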

Example 60 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

The class While, method doImport:

private void doImport(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph, Set<String> skipSet, AtomicInteger currIndex) {
    val uniqueId = java.util.UUID.randomUUID().toString();
    skipSet.add(nodeDef.getName());
    val scopeCondition = SameDiff.create();
    val scopeLoop = SameDiff.create();
    initWith.putSubFunction("condition-" + uniqueId, scopeCondition);
    initWith.putSubFunction("loopbody-" + uniqueId, scopeLoop);
    this.loopBodyExecution = scopeLoop;
    this.predicateExecution = scopeCondition;
    this.startPosition = currIndex;
    log.info("Adding 2 new scopes for WHILE {}");
    val nodes = graph.getNodeList();
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        if (!tfNode.getOp().equalsIgnoreCase("enter")) {
            // skipSet.add(tfNode.getName());
            break;
        }
        // if (skipSet.contains(tfNode.getName()))
        // continue;
        skipSet.add(tfNode.getName());
        val vars = new SDVariable[tfNode.getInputCount()];
        for (int e = 0; e < tfNode.getInputCount(); e++) {
            val input = TFGraphMapper.getInstance().getNodeName(tfNode.getInput(e));
            vars[e] = initWith.getVariable(input) == null ? initWith.var(input, null, new ZeroInitScheme()) : initWith.getVariable(input);
            scopeCondition.var(vars[e]);
            scopeLoop.var(vars[e]);
        }
        this.inputVars = vars;
    }
    // now we're skipping Merge step, since we've already captured variables at Enter step
    int mergedCnt = 0;
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        if (!tfNode.getOp().equalsIgnoreCase("merge")) {
            scopeLoop.var(TFGraphMapper.getInstance().getNodeName(tfNode.getName()), null, new ZeroInitScheme());
            break;
        }
        skipSet.add(tfNode.getName());
        val var = scopeLoop.var(TFGraphMapper.getInstance().getNodeName(tfNode.getName()), null, new ZeroInitScheme());
        scopeCondition.var(var);
        initWith.var(var);
        mergedCnt++;
    }
    // now, we're adding conditional scope
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        // we're parsing up to condition
        if (tfNode.getOp().equalsIgnoreCase("LoopCond")) {
            skipSet.add(tfNode.getName());
            currIndex.incrementAndGet();
            break;
        }
        boolean isConst = tfNode.getOp().equalsIgnoreCase("const");
        boolean isVar = tfNode.getOp().startsWith("VariableV");
        boolean isPlaceholder = tfNode.getOp().startsWith("Placeholder");
        if (isConst || isVar || isPlaceholder) {
            val var = scopeCondition.var(tfNode.getName(), null, new ZeroInitScheme());
            scopeLoop.var(var);
            initWith.var(var);
            log.info("Adding condition var [{}]", var.getVarName());
        } else if (!skipSet.contains(tfNode.getName())) {
            val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName());
            func.initFromTensorFlow(tfNode, scopeCondition, nodeDef.getAttrMap(), graph);
            func.setSameDiff(scopeLoop);
        }
        skipSet.add(tfNode.getName());
    }
    // time to skip some Switch calls
    int switchCnt = 0;
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        // we're parsing up to condition
        if (!tfNode.getOp().equalsIgnoreCase("Switch"))
            break;
        switchCnt++;
        skipSet.add(tfNode.getName());
    }
    // now we're parsing Identity step
    int identityCnt = 0;
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        if (!tfNode.getOp().equalsIgnoreCase("Identity")) {
            break;
        }
        val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName());
        func.initFromTensorFlow(tfNode, initWith, nodeDef.getAttrMap(), graph);
        func.setSameDiff(scopeLoop);
        val variables = new SDVariable[tfNode.getInputCount()];
        for (int i = 0; i < tfNode.getInputCount(); i++) {
            val testVar = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i)));
            if (testVar == null) {
                variables[i] = initWith.var(tfNode.getInput(i), null, new ZeroInitScheme());
                scopeCondition.var(variables[i]);
                scopeLoop.var(variables[i]);
                continue;
            } else {
                variables[i] = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i)));
                scopeCondition.var(variables[i]);
                scopeLoop.var(variables[i]);
            }
        }
        scopeLoop.addArgsFor(variables, func);
        skipSet.add(tfNode.getName());
    }
    // parsing body scope
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        if (skipSet.contains(tfNode.getName())) {
            log.info("Skipping: {}", tfNode.getName());
            continue;
        }
        if (tfNode.getOp().equalsIgnoreCase("NextIteration")) {
            // skipSet.add(tfNode.getName());
            break;
        }
        boolean isConst = tfNode.getOp().equalsIgnoreCase("const");
        boolean isVar = tfNode.getOp().startsWith("VariableV");
        boolean isPlaceholder = tfNode.getOp().startsWith("Placeholder");
        if (isConst || isVar || isPlaceholder) {
            val var = scopeLoop.var(tfNode.getName(), null, new ZeroInitScheme());
            log.info("Adding body var [{}]", var.getVarName());
        } else {
            log.info("starting on [{}]: {}", tfNode.getName(), tfNode.getOp());
            if (tfNode.getOp().equalsIgnoreCase("enter")) {
                log.info("NEW LOOP ----------------------------------------");
                val func = new While(currIndex);
                func.doImport(nodeDef, initWith, attributesForNode, graph, skipSet, currIndex);
                func.setSameDiff(initWith);
                log.info("END LOOP ----------------------------------------");
            } else {
                val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName());
                func.initFromTensorFlow(tfNode, initWith, nodeDef.getAttrMap(), graph);
                func.setSameDiff(scopeCondition);
                val variables = new SDVariable[tfNode.getInputCount()];
                for (int i = 0; i < tfNode.getInputCount(); i++) {
                    val name = TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i));
                    variables[i] = scopeCondition.getVariable(name);
                    if (variables[i] == null) {
                        // fall back to the loop scope first, otherwise pull the
                        // variable in from the outer graph via the condition scope
                        if (scopeLoop.getVariable(name) != null)
                            variables[i] = scopeLoop.getVariable(name);
                        else
                            variables[i] = scopeCondition.var(initWith.getVariable(name));
                    }
                }
                scopeLoop.addArgsFor(variables, func);
            }
        }
        skipSet.add(tfNode.getName());
    }
    val returnInputs = new ArrayList<SDVariable>();
    val returnOutputs = new ArrayList<SDVariable>();
    // mapping NextIterations, to Return op
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        if (!tfNode.getOp().equalsIgnoreCase("NextIteration"))
            break;
        skipSet.add(tfNode.getName());
        val inputName = TFGraphMapper.getInstance().getNodeName(tfNode.getName());
        val input = initWith.getVariable(inputName) == null ? initWith.var(inputName, null, new ZeroInitScheme()) : initWith.getVariable(inputName);
        returnInputs.add(input);
    }
    // note: returnOutputs is never populated above, so outputVars ends up empty here
    this.outputVars = returnOutputs.toArray(new SDVariable[returnOutputs.size()]);
    this.inputVars = returnInputs.toArray(new SDVariable[returnInputs.size()]);
    initWith.addArgsFor(inputVars, this);
    initWith.addOutgoingFor(outputVars, this);
    // we should also map While/Exit to libnd4j while
    int exitCnt = 0;
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        if (!tfNode.getOp().equalsIgnoreCase("Exit")) {
            // skipSet.add(tfNode.getName());
            break;
        }
        skipSet.add(tfNode.getName());
        val inputName = TFGraphMapper.getInstance().getNodeName(tfNode.getName());
        // registers the Exit output in the outer graph as a side effect if it is not present yet
        val input = initWith.getVariable(inputName) == null ? initWith.var(inputName, null, new ZeroInitScheme()) : initWith.getVariable(inputName);
    }
    // the output of the condition should always be a singular scalar
    // this is a safe assumption
    val conditionVars = scopeCondition.functions();
    if (conditionVars.length < 1) {
        throw new ND4JIllegalArgumentException("No functions found!");
    }
    this.targetBoolean = conditionVars[conditionVars.length - 1].outputVariables()[0];
    log.info("-------------------------------------------");
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), ZeroInitScheme (org.nd4j.weightinit.impl.ZeroInitScheme), ND4JIllegalArgumentException (org.nd4j.linalg.exception.ND4JIllegalArgumentException)
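TensorFlow serializes a while-loop as a fixed sequence of control-flow nodes, and doImport consumes them phase by phase with a shared AtomicInteger cursor: Enter nodes (loop inputs), Merge nodes, the condition subgraph up to LoopCond, Switch nodes, Identity nodes, the loop body, NextIteration nodes, and finally Exit nodes. The following self-contained sketch (not nd4j code; PhaseParseSketch and consumePhase are hypothetical names) illustrates just that cursor-driven consumption pattern:

import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

public class PhaseParseSketch {

    // Advance the shared cursor while the current node's op matches this phase,
    // mirroring the repeated for-loops over currIndex in doImport above.
    static void consumePhase(List<String> ops, AtomicInteger cursor, String opName) {
        while (cursor.get() < ops.size() && ops.get(cursor.get()).equalsIgnoreCase(opName)) {
            System.out.println("phase " + opName + ": consumed " + ops.get(cursor.get()));
            cursor.incrementAndGet();
        }
    }

    public static void main(String[] args) {
        // a TF while-loop in serialized order:
        // Enter* Merge* <cond> LoopCond Switch* Identity* <body> NextIteration* Exit*
        List<String> ops = List.of("Enter", "Enter", "Merge", "Merge", "Less", "LoopCond",
                "Switch", "Switch", "Identity", "Identity", "Add", "NextIteration", "Exit");
        AtomicInteger cursor = new AtomicInteger(0);
        consumePhase(ops, cursor, "Enter");
        consumePhase(ops, cursor, "Merge");
        // the condition ("Less") and body ("Add") phases are consumed up to their
        // terminators (LoopCond / NextIteration) rather than by op name
    }
}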

Aggregations

SDVariable (org.nd4j.autodiff.samediff.SDVariable): 104
SameDiff (org.nd4j.autodiff.samediff.SameDiff): 41
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 38
Test (org.junit.Test): 36
ArrayList (java.util.ArrayList): 18
DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp): 10
lombok.val (lombok.val): 7
LossFunctions (org.nd4j.autodiff.loss.LossFunctions): 4
LossInfo (org.nd4j.autodiff.loss.LossInfo): 4
BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution): 4
Ignore (org.junit.Ignore): 3
DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction): 3
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 3
Triple (org.nd4j.linalg.primitives.Triple): 2
DataOutputStream (java.io.DataOutputStream): 1
FileOutputStream (java.io.FileOutputStream): 1
ByteBuffer (java.nio.ByteBuffer): 1
NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException): 1
NdIndexIterator (org.nd4j.linalg.api.iter.NdIndexIterator): 1
TruncateDivOp (org.nd4j.linalg.api.ops.impl.transforms.arithmetic.TruncateDivOp): 1