Search in sources :

Example 1 with DifferentialFunction

use of org.nd4j.autodiff.functions.DifferentialFunction in project nd4j by deeplearning4j.

the class SameDiff method execBackwards.

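/**
 * Builds a backwards graph
 * and executes the operations
 * on that graph.
 * @return the pair (variable-to-function map, function list) produced by executing the "grad" function
 */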
// Defines the "grad" function on first use, executes it via exec("grad") and returns that result.
// NOTE(review): this excerpt appears to have lost lines during extraction (closing braces and
// some statements are missing); the indentation is the best guide to the intended nesting.
public Pair<Map<SDVariable, DifferentialFunction>, List<DifferentialFunction>> execBackwards() {
    final SameDiff outer = this;
    // the "grad" function is only defined once; later calls reuse the registered definition
    if (getFunction("grad") == null)
        defineFunction("grad", new SameDiffFunctionDefinition() {

            // Builds the backward pass inside the SameDiff instance passed in as `sameDiff`.
            public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
                // which will also contain the backward
                if (SameDiff.this.debugMode) {
                // snapshot of every function currently registered on sameDiff
                List<DifferentialFunction> allFunctions = new ArrayList<>(sameDiff.functionInstancesById.values());
                if (allFunctions.isEmpty()) {
                    throw new ND4JIllegalStateException("No ops found!");
                // re-point the arguments and outputs of each function at sameDiff
                for (val func : allFunctions) {
                    if (func instanceof SDVariable) {
                    val args = func.args();
                    for (val arg : args) arg.setSameDiff(sameDiff);
                    val outputs = func.outputVariables();
                    for (val output : outputs) output.setSameDiff(sameDiff);
                // backprop starts from the first output variable of the last registered function
                val initialOuts = allFunctions.get(allFunctions.size() - 1).outputVariables();
                val firstBackward = initialOuts[0];
                // start with scalar backprop
                // the seed gradient is a scalar 1.0 variable named "one-var", registered under
                // the starting variable's name in both forwardVarForGrad and gradients
                SDVariable initialGrad = sameDiff.var("one-var", Nd4j.scalar(1.0));
                sameDiff.forwardVarForGrad.put(firstBackward.getVarName(), initialGrad);
                sameDiff.gradients.put(firstBackward.getVarName(), initialGrad);
                SDVariable gradientBackwardsMarker = sameDiff.gradientBackwardsMarker(firstBackward);
                // reinitialize list with all declared variables
                allFunctions = new ArrayList<DifferentialFunction>(sameDiff.functionInstancesById.values());
                for (DifferentialFunction action : allFunctions) {
                    // NOTE(review): the body of this marker check and the condition guarding the
                    // warning below are not visible in this excerpt
                    if (action instanceof GradientBackwardsMarker) {
                        log.warn("Action op state is null for " + action.opName());
                    DifferentialFunction currFunction = action;
                    // every function differentiated here must belong to the instance being defined
                    Preconditions.checkState(currFunction.getSameDiff() == sameDiff, "Wrong samediff instance found!");
                    // Preconditions.checkNotNull("Gradient for " + currFunction.opName() + " was null ! " + sameDiff.getVariableForVertexId(currFunction.getVertexId()).getGradient());
                    // note: despite the name, `args` holds the function's OUTPUT variables
                    val args = currFunction.outputVariables();
                    for (val arg : args) {
                        if (arg.getSameDiff() != sameDiff) {
                    // look up the gradient of each output variable; a missing gradient is fatal
                    List<SDVariable> grads = new ArrayList<>();
                    for (val varToGrad : args) {
                        val grad = varToGrad.gradient();
                        if (grad == null)
                            throw new ND4JIllegalStateException("No gradient found for " + varToGrad.getVarName());
                    // differentiate the function with respect to the collected gradients
                    List<SDVariable> currFnGrads = currFunction.diff(grads);
                if (sameDiff.isDebugMode()) {
                    // ensure all gradients are present for all variables
                    for (SDVariable sdVariable : variables()) {
                // the definition returns a single variable named "grad" created with shape {1, 1}
                return new SDVariable[] { sameDiff.var("grad", new int[] { 1, 1 }) };
    // run the "grad" function; its result is what this method returns
    Pair<Map<SDVariable, DifferentialFunction>, List<DifferentialFunction>> forward = exec("grad");
    SameDiff grad = getFunction("grad");
    if (grad.isDebugMode()) {
        // ensure all gradients are present for all variables
        for (SDVariable sdVariable : grad.variables()) {
    return forward;
Also used : GradientBackwardsMarker(org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker) DifferentialFunction(org.nd4j.autodiff.functions.DifferentialFunction) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) IntArrayKeyMap(org.nd4j.linalg.collection.IntArrayKeyMap)

Example 2 with DifferentialFunction

use of org.nd4j.autodiff.functions.DifferentialFunction in project nd4j by deeplearning4j.

the class SameDiff method exec.

/**
 * Creates and executes a list of operations
 * @return a pair holding the variable-to-function map and the list of operations
 */
// Walks every function registered in functionInstancesById in order, dispatching on the
// function's type, and tracks active/executed nodes and frames in a fresh FlowPath stored in
// localFlowPath. Returns the (opMap, ops) pair.
// NOTE(review): this excerpt appears to have lost lines during extraction (closing braces,
// block-comment delimiters and some statements are missing); the indentation is the best
// guide to the intended nesting.
public Pair<Map<SDVariable, DifferentialFunction>, List<DifferentialFunction>> exec() {
    // resolve variables against an empty map if they have not been resolved yet
    if (!resolvedVariables)
        resolveVariablesWith(new LinkedHashMap<String, INDArray>());
    List<DifferentialFunction> ops = new ArrayList<>();
    // we don't care if this thread had any other FlowPath objects attached. we'll just create new one
    localFlowPath.set(new FlowPath());
    val flowPath = localFlowPath.get();
    Map<SDVariable, DifferentialFunction> opMap = new HashMap<>();
    // snapshot of the registered functions; the loop below indexes into this list
    val funcs = new ArrayList<DifferentialFunction>(functionInstancesById.values());
    // set to true once the GradientBackwardsMarker op has been reached
    boolean onBackward = false;
    // deque of frame names (nested, probably)
    val frames = new ArrayDeque<String>();
    // simple flag, set true if within frame
    boolean inFrame = false;
    // yet another flag, to remove LastFrame once we really left last frame
    boolean frameLeft = false;
    // i is declared outside the loop header because the Exit branch reassigns it to rewind
    int i = 0;
    // in the visible code exec_counter is only used in log messages
    int exec_counter = 0;
    for (; i < funcs.size(); i++) {
        val opName = funcs.get(i).opName();
        // the marker is detected by comparing op names with a freshly constructed GradientBackwardsMarker
        if (!onBackward && opName.equals(new GradientBackwardsMarker().opName())) {
            onBackward = true;
        if (opName.equals(new GradientBackwardsMarker().opName()))
        DifferentialFunction differentialFunction = funcs.get(i);
        val ownName = differentialFunction.getOwnName();
        // just registering function for this pass
        if (differentialFunction instanceof SDVariable) {
        val args = getInputsForFunction(differentialFunction);
        log.debug("Step: {}; Executing op {} for node [{}]", exec_counter, opName, ownName);
        // check if inputs are active nodes. skip step otherwise
        // please note: Exit node can't be skipped, because it's either rewind point or exit loop point
        boolean shouldSkip = false;
        if (differentialFunction instanceof Merge) {
            // a Merge is skipped only when BOTH of its first two inputs are inactive
            val arg0 = args[0];
            val arg1 = args[1];
            if (!flowPath.isActive(arg0) && !flowPath.isActive(arg1))
                shouldSkip = true;
        } else {
            if (!(differentialFunction instanceof Exit)) {
                // if we've left Exit nodes, we can finally delete last frame name
                if (frameLeft) {
                    frameLeft = false;
                    val frame_name = frames.removeLast();
                    flowPath.activateFrame(frame_name, false);
                // we must check whether any inactive nodes are used as inputs for this node
                for (val input : args) {
                    if (!flowPath.isActive(input)) {
                        // propagate inactivity
                        flowPath.markActive(differentialFunction.getOwnName(), false);
                        shouldSkip = true;
        if (shouldSkip)
        flowPath.markActive(differentialFunction.getOwnName(), true);
         * This set of operations (Enter/Exit/NextIteration/Exit/Switch) are special snowflakes: they modify graph execution order, and basically used here to replicate TF logic.
         * Since SameDiff itself has own logic for loops and conditionals using Scopes
        if (differentialFunction instanceof LoopCond) {
            // this node just passes single input forward, for future evaluation
            val inputs = getInputVariablesForFunction(differentialFunction);
            val array = inputs[0].getArr();
            // the input array is duplicated (same ordering) and stored under this node's own name
            variableNameToArr.put(differentialFunction.getOwnName(), array.dup(array.ordering()));
            flowPath.markExecuted(differentialFunction.getOwnName(), true);
            // a first element equal to 1 (after the int cast) is treated as "condition is true"
            if ((int) array.getDouble(0) == 1) {
                val frameName = frames.getLast();
                // incrementing number of cycles for THIS frame, only if LoopCond is true
        } else if (differentialFunction instanceof Enter) {
            // if (flowPath.wasExecuted(differentialFunction.getOwnName()))
            // continue;
            val inputs = getInputVariablesForFunction(differentialFunction);
            val array = inputs[0].getArr();
            variableNameToArr.put(differentialFunction.getOwnName(), array.dup(array.ordering()));
            flowPath.markExecuted(differentialFunction.getOwnName(), true);
            // frame_name MUST be non-null here
            val frame_name = ((Enter) differentialFunction).getFrameName();
            if (!flowPath.isRegisteredFrame(frame_name)) {
                inFrame = true;
        } else if (differentialFunction instanceof Exit) {
            // this is just exit point of graph: it maps own input to own output or rewinds graph to specific position planned at first NextIteration node
            val frame_name = frames.getLast();
            // saving frame_name for backward pass
            ((Exit) differentialFunction).setFrameName(frame_name);
            if (!flowPath.isFrameActive(frame_name)) {
                flowPath.markActive(differentialFunction.getOwnName(), false);
                // if frame is inactive, lets remove it from queue as well
                frameLeft = true;
            // and if it's TRUE - we apply the rewind by setting the loop idx and calling continue
            if (flowPath.isRewindPlanned(frame_name)) {
                // just reset loop
                flowPath.planRewind(frame_name, false);
                // NOTE(review): currentPosition and startPosition are not used in the visible code
                val currentPosition = i;
                i = flowPath.getRewindPosition(frame_name);
                val startPosition = i + 1;
                flowPath.setRewindPosition(frame_name, -1);
            val inputs = getInputVariablesForFunction(differentialFunction);
            val array = inputs[0].getArr();
            variableNameToArr.put(differentialFunction.getOwnName(), array.dup(array.ordering()));
            flowPath.markExecuted(differentialFunction.getOwnName(), true);
            // now it's safe to remove LastFrame
            frameLeft = true;
        } else if (differentialFunction instanceof NextIteration) {
            // this operations merges own input, and schedules rewind to specific Merge node
            val inputs = getInputVariablesForFunction(differentialFunction);
            val frame_name = frames.getLast();
            val array = inputs[0].getArr();
            variableNameToArr.put(differentialFunction.getOwnName(), array.dup(array.ordering()));
            flowPath.markExecuted(differentialFunction.getOwnName(), true);
            // if NextIteration wasn't skipped with inactive branch, we'll plan rewind for this frame. obviously, only once
            if (!flowPath.isRewindPlanned(frame_name)) {
                flowPath.planRewind(frame_name, true);
        } else if (differentialFunction instanceof Merge) {
            // merge operation takes two inputs, and saves one of them as own output.
            // if SDVariable exists for second input - we use it. First input used otherwise
            val inputs = getInputVariablesForFunction(differentialFunction);
            val frame_name = frames.size() > 0 ? frames.getLast() : null;
            if (frame_name != null)
                flowPath.activateFrame(frame_name, true);
            // frame_name can be null if this merge node is used for something that's not loop. i.e. switch/merge pair
            // the rewind position is i - 1, presumably so that the loop's i++ lands back on this Merge - TODO confirm
            if (frame_name != null)
                flowPath.setRewindPositionOnce(frame_name, i - 1);
            // NextIteration can have NO frame_name defined. so let's propagate it
            if (inputs.length == 2) {
                val secondArg = functionInstancesById.get(inputs[1].getVarName());
                if (secondArg != null && secondArg instanceof NextIteration) {
                    ((NextIteration) secondArg).setFrameName(frame_name);
            // we must check second input first here
            if (flowPath.wasExecuted(inputs[1].getVarName())) {
                // propagate second input
                val array = inputs[1].getArr();
                variableNameToArr.put(differentialFunction.getOwnName(), array.dup(array.ordering()));
                // nullify executed mark
                flowPath.markExecuted(inputs[1].getVarName(), false);
            } else {
                // propagate first input
                val array = inputs[0].getArr();
                variableNameToArr.put(differentialFunction.getOwnName(), array.dup(array.ordering()));
            flowPath.markExecuted(differentialFunction.getOwnName(), true);
        } else if (differentialFunction instanceof Switch) {
            // switch takes 2 inputs: actual input and boolean scalar. If scalar is false, input is saved as output:0, if scalar is true, input is saved as output:1
            ((CustomOp) differentialFunction).populateInputsAndOutputsFromSameDiff();
            val inputs = getInputVariablesForFunction(differentialFunction);
            val input = inputs[0].getArr();
            val bool = inputs[1].getArr();
            // basically we're setting one of the graph branches inactive. branch 0 for false, branch 1 for true
            // output:0 is keyed by the node's own name, output:1 by own name + ":1"
            if ((int) bool.getDouble(0) == 0) {
                // false step, we'll propagate output:0 here
                flowPath.setActiveBranch(differentialFunction.getOwnName(), 0);
                flowPath.markActive(differentialFunction.getOwnName(), true);
                flowPath.markActive(differentialFunction.getOwnName() + ":1", false);
                variableNameToArr.put(differentialFunction.getOwnName(), input.dup(input.ordering()));
            } else {
                // true step, we'll propagate output:1 here
                flowPath.setActiveBranch(differentialFunction.getOwnName(), 1);
                variableNameToArr.put(differentialFunction.getOwnName() + ":1", input.dup(input.ordering()));
                flowPath.markActive(differentialFunction.getOwnName(), false);
                flowPath.markActive(differentialFunction.getOwnName() + ":1", true);
            flowPath.markExecuted(differentialFunction.getOwnName(), true);
        } else if (differentialFunction instanceof If) {
            If ifOp = (If) differentialFunction;
            if (!onBackward) {
                // and possible later processing.
                // forward pass: the branch is chosen by whether the target boolean array sums to more than 0
                if (ifOp.getTargetBoolean().getArr().sumNumber().doubleValue() > 0) {
                } else {
            } else {
                // backward pass: run execBackwards on whichever body the forward pass recorded as executed
                if (ifOp.getTrueBodyExecuted() != null) {
                    Pair<Map<SDVariable, DifferentialFunction>, List<DifferentialFunction>> execBackwards = null;
                    List<SDVariable> variablesForFunctions = null;
                    if (ifOp.getTrueBodyExecuted()) {
                        execBackwards = ifOp.getLoopBodyExecution().execBackwards();
                        variablesForFunctions = ifOp.getLoopBodyExecution().getVariablesAssociatedWithFunctions(execBackwards.getRight());
                    } else {
                        execBackwards = ifOp.getFalseBodyExecution().execBackwards();
                        variablesForFunctions = ifOp.getFalseBodyExecution().getVariablesAssociatedWithFunctions(execBackwards.getRight());
                     * Maps the variables from the child namespace body to
                     * the parent. This allows access to the underlying ndarray
                     * and returning a valid variable reference for autodiff.
                    for (SDVariable variable : variablesForFunctions) {
                        SDVariable proxyVar = var(variable);
                } else
                    throw new ND4JIllegalStateException("No body was run.");
            flowPath.markExecuted(differentialFunction.getOwnName(), true);
        } else if (differentialFunction instanceof While) {
            While whileOp = (While) differentialFunction;
            if (!onBackward) {
                SameDiff execBody = whileOp.getLoopBodyExecution();
                // depending on the block add the proper graph body to this for persistence
                // and possible later processing.
                // note that we need to update the graph predicate by running the execution
                // the loop continues while the target boolean array sums to more than 0
                while (whileOp.getTargetBoolean().getArr().sumNumber().doubleValue() > 0) {
                    // run the body
                    // update the predicate
                List<SDVariable> outputs = new ArrayList<>();
                // output variables of the last function registered in the loop body
                val outputFuncArgs = new ArrayList<>(execBody.functionInstancesById.values()).get(execBody.functionInstancesById.values().size() - 1).outputVariables();
                whileOp.setOutputVars(outputs.toArray(new SDVariable[outputs.size()]));
            } else {
                 * Note: Need to accumulate gradients.
                 * Multiply each value by the number of times looped.
                 * This approximates accumulating the gradient
                 * across a number of loop cycles.
                 * We only compute the gradient for the internal loop once
                 * and from that we multiply the gradient by 5.
                Pair<Map<SDVariable, DifferentialFunction>, List<DifferentialFunction>> mapListPair = whileOp.getLoopBodyExecution().execBackwards();
                for (SDVariable variable : mapListPair.getFirst().keySet()) {
            flowPath.markExecuted(differentialFunction.getOwnName(), true);
        } else if (differentialFunction instanceof CustomOp) {
            // NOTE(review): the branch tests for CustomOp but casts to DynamicCustomOp; if a CustomOp
            // that is not a DynamicCustomOp can reach this point the cast throws - TODO confirm
            DynamicCustomOp customOp = (DynamicCustomOp) differentialFunction;
                // NOTE(review): the logging calls that own the two format strings below are missing
                // from this excerpt; the LessThan message also prints "<=" - verify intent
                if (customOp instanceof LessThanOrEqual) {
          "Step: {}; InnerCondition: {} <= {} = {}", exec_counter, customOp.getInputArgument(0), customOp.getInputArgument(1), customOp.getOutputArgument(0));
                } else if (customOp instanceof LessThan) {
          "Step: {}; OuterCondition: {} <= {} = {}", exec_counter, customOp.getInputArgument(0), customOp.getInputArgument(1), customOp.getOutputArgument(0));
            flowPath.markExecuted(differentialFunction.getOwnName(), true);
        } else if (differentialFunction instanceof Op) {
            val inputs = getInputVariablesForFunction(differentialFunction);
            Op op = (Op) differentialFunction;
            // ops in differential function might have stale NDArrays used. we should renew them
            if (inputs.length == 2)
            if (differentialFunction.getDimensions() == null)
            else if (op.isExecSpecial()) {
            } else {
                // dimension-aware execution: dispatch on the concrete op family and pass the axes along
                int[] axes = differentialFunction.getDimensions();
                if (differentialFunction instanceof Accumulation) {
                    Accumulation accumulation = (Accumulation) differentialFunction;
                    Nd4j.getExecutioner().exec(accumulation, axes);
                    // when the first output variable has no array yet, adopt the op's z() result and its shape
                    if (differentialFunction.outputVariables()[0].getArr() == null) {
                        val var = differentialFunction.outputVariables()[0];
                        updateArrayForVarName(var.getVarName(), accumulation.z());
                        updateShapeForVarName(var.getVarName(), accumulation.z().shape());
                } else if (differentialFunction instanceof BroadcastOp) {
                    BroadcastOp broadcastOp = (BroadcastOp) differentialFunction;
                    Nd4j.getExecutioner().exec(broadcastOp, axes);
                } else if (differentialFunction instanceof GradientOp) {
                } else if (differentialFunction instanceof IndexAccumulation) {
                    IndexAccumulation indexAccumulation = (IndexAccumulation) differentialFunction;
                    Nd4j.getExecutioner().exec(indexAccumulation, axes);
                } else if (differentialFunction instanceof TransformOp) {
                    TransformOp t = (TransformOp) differentialFunction;
                    Nd4j.getExecutioner().exec(t, axes);
            flowPath.markExecuted(differentialFunction.getOwnName(), true);
    // debug
    // printFunction(differentialFunction);
    return new Pair<>(opMap, ops);
Also used : FlowPath(org.nd4j.autodiff.samediff.flow.FlowPath) Pair(org.nd4j.linalg.primitives.Pair) GradientBackwardsMarker(org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker) While(org.nd4j.linalg.api.ops.impl.controlflow.While) DifferentialFunction(org.nd4j.autodiff.functions.DifferentialFunction) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) If(org.nd4j.linalg.api.ops.impl.controlflow.If) IntArrayKeyMap(org.nd4j.linalg.collection.IntArrayKeyMap)

Example 3 with DifferentialFunction

use of org.nd4j.autodiff.functions.DifferentialFunction in project nd4j by deeplearning4j.

the class SameDiff method addArgsFor.

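/**
 * Adds incoming args to the graph
 * @param variables the names of the variables that are arguments of the function
 * @param function the function the given variables are registered as arguments for
 */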
// Registers `variables` as the incoming arguments of `function` in incomingArgs,
// incomingArgsReverse and the per-variable lists in functionsArgsFor.
// NOTE(review): this excerpt appears to have lost lines during extraction (closing braces,
// the body of the placeholder check, and presumably the append to `funcs`).
public void addArgsFor(String[] variables, DifferentialFunction function) {
    // the function's own name is used as the key below, so it must be set
    if (function.getOwnName() == null)
        throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
    // double check if function contains placeholder args
    for (val varName : variables) {
        if (isPlaceHolder(varName)) {
    // forward mapping (names -> function) and reverse mapping (function's own name -> names)
    incomingArgs.put(variables, function);
    incomingArgsReverse.put(function.getOwnName(), variables);
    for (val variableName : variables) {
        // lazily create the list of functions for this variable name
        List<DifferentialFunction> funcs = functionsArgsFor.get(variableName);
        if (funcs == null) {
            funcs = new ArrayList<>();
            functionsArgsFor.put(variableName, funcs);
Also used : DifferentialFunction(org.nd4j.autodiff.functions.DifferentialFunction) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Example 4 with DifferentialFunction

use of org.nd4j.autodiff.functions.DifferentialFunction in project nd4j by deeplearning4j.

the class SameDiff method eval.

/**
 * Evaluate the given inputs
 * based on the current graph
 * @param inputs the inputs to evaluate
 * @return one array per executed operation, taken from that operation's first output variable
 */
// Executes a duplicate of this graph and returns, for each executed operation, the array
// bound to that operation's first output variable.
// NOTE(review): `inputs` is never read in the visible body - confirm whether it should be
// applied to the duplicated graph before exec().
// NOTE(review): the closing braces of the loop and the method are missing from this excerpt.
public INDArray[] eval(Map<String, INDArray> inputs) {
    // work on a duplicate so execution goes through execPipeline rather than this instance
    SameDiff execPipeline = dup();
    // the right-hand element of exec()'s pair is the list of operations
    List<DifferentialFunction> opExecAction = execPipeline.exec().getRight();
    if (opExecAction.isEmpty())
        throw new IllegalStateException("No ops found to execute.");
    INDArray[] ret = new INDArray[opExecAction.size()];
    for (int i = 0; i < ret.length; i++) {
        // only the FIRST output variable of each operation is reported
        val varName = opExecAction.get(i).outputVariables()[0].getVarName();
        ret[i] = execPipeline.getArrForVarName(varName);
    return ret;
Also used : ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) DifferentialFunction(org.nd4j.autodiff.functions.DifferentialFunction)

Example 5 with DifferentialFunction

use of org.nd4j.autodiff.functions.DifferentialFunction in project nd4j by deeplearning4j.

the class SameDiff method addOutgoingFor.

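/**
 * Adds outgoing arguments to the graph.
 * Also checks for input arguments
 * and updates the graph adding an appropriate edge
 * when the full graph is declared.
 * @param varNames the names of the variables the function outputs; neither the array nor any element may be null
 * @param function the function the given variables are registered as outputs of
 */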
// Registers `varNames` as the outgoing (output) variables of `function` in outgoingArgs,
// outgoingArgsReverse and the per-variable lists in functionOutputFor, after validating both
// arguments. Throws ND4JIllegalStateException if the function has no own name, if outgoing
// arguments were already declared for it, or if varNames or any of its elements is null.
// NOTE(review): this excerpt appears to have lost lines during extraction (closing braces
// and presumably the append to `funcs`).
public void addOutgoingFor(String[] varNames, DifferentialFunction function) {
    // the function's own name is used as the key below, so it must be set
    if (function.getOwnName() == null)
        throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
    // outputs may only be declared once per function
    if (outgoingArgsReverse.containsKey(function.getOwnName())) {
        throw new ND4JIllegalStateException("Outgoing arguments already declared for " + function);
    if (varNames == null)
        throw new ND4JIllegalStateException("Var names can not be null!");
    for (int i = 0; i < varNames.length; i++) {
        if (varNames[i] == null)
            throw new ND4JIllegalStateException("Variable name elements can not be null!");
    // reverse mapping (function's own name -> names) and forward mapping (names -> function)
    outgoingArgsReverse.put(function.getOwnName(), varNames);
    outgoingArgs.put(varNames, function);
    for (val resultName : varNames) {
        // lazily create the list of functions for this output name
        List<DifferentialFunction> funcs = functionOutputFor.get(resultName);
        if (funcs == null) {
            funcs = new ArrayList<>();
            functionOutputFor.put(resultName, funcs);
Also used : DifferentialFunction(org.nd4j.autodiff.functions.DifferentialFunction) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)


DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction) 18, INDArray (org.nd4j.linalg.api.ndarray.INDArray) 9, Test (org.junit.Test) 7, ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException) 6, ArrayList (java.util.ArrayList) 3, lombok.val (lombok.val) 3, SDVariable (org.nd4j.autodiff.samediff.SDVariable) 3, SameDiff (org.nd4j.autodiff.samediff.SameDiff) 2, NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException) 2, DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp) 2, Op (org.nd4j.linalg.api.ops.Op) 2, GradientBackwardsMarker (org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker) 2, IntArrayKeyMap (org.nd4j.linalg.collection.IntArrayKeyMap) 2, Set (java.util.Set) 1, AtomicInteger (java.util.concurrent.atomic.AtomicInteger) 1, FlowPath (org.nd4j.autodiff.samediff.flow.FlowPath) 1, NdIndexIterator (org.nd4j.linalg.api.iter.NdIndexIterator) 1, Variance (org.nd4j.linalg.api.ops.impl.accum.Variance) 1, If (org.nd4j.linalg.api.ops.impl.controlflow.If) 1, While (org.nd4j.linalg.api.ops.impl.controlflow.While) 1