Usage of org.nd4j.autodiff.functions.DifferentialFunction in project nd4j by deeplearning4j.
Class SameDiff, method execBackwards.
/**
* Builds a backwards (gradient) graph and executes the operations on that graph.
* <p>
* The first call defines a sub-function named {@code "grad"}; later calls reuse it because the
* definition is skipped whenever {@code getFunction("grad")} is already non-null. The definition:
* <ol>
* <li>copies this graph into the child instance via {@code invokeGraphOn} and rebinds every
* non-SDVariable function, together with its args and outputs, to the child instance;</li>
* <li>seeds the gradient of the first output of the last function (in
* {@code functionInstancesById} iteration order) with a scalar {@code 1.0} variable named
* {@code "one-var"}, then adds a gradient-backwards marker for that output;</li>
* <li>walks all functions in reverse iteration order, skipping the marker, collects the
* gradient of each output and calls {@code diff(grads)} on the function.</li>
* </ol>
* The {@code "grad"} function is then run with {@code exec("grad")}; in debug mode
* {@code gradient()} is additionally requested for every variable of the {@code "grad"} graph.
*
* @return the result of {@code exec("grad")}
* @throws ND4JIllegalStateException (from the {@code "grad"} definition) if the copied graph
* contains no functions, or if an output of a function has no gradient when that function is
* differentiated
*/
public Pair<Map<SDVariable, DifferentialFunction>, List<DifferentialFunction>> execBackwards() {
final SameDiff outer = this;
// define the "grad" sub-function only once; later calls reuse the existing definition
if (getFunction("grad") == null)
defineFunction("grad", new SameDiffFunctionDefinition() {
@Override
public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
// sameDiff is the child graph backing "grad": it receives a copy of the forward graph
// and will also contain the backward ops
if (SameDiff.this.debugMode) {
sameDiff.enableDebugMode();
}
outer.invokeGraphOn(sameDiff);
List<DifferentialFunction> allFunctions = new ArrayList<>(sameDiff.functionInstancesById.values());
if (allFunctions.isEmpty()) {
throw new ND4JIllegalStateException("No ops found!");
}
// rebind every op (and its args/outputs) to the child graph; SDVariables are skipped
for (val func : allFunctions) {
if (func instanceof SDVariable) {
continue;
}
val args = func.args();
for (val arg : args) arg.setSameDiff(sameDiff);
val outputs = func.outputVariables();
for (val output : outputs) output.setSameDiff(sameDiff);
func.setSameDiff(sameDiff);
}
// the first output of the last function (in map iteration order) is the value being differentiated
val initialOuts = allFunctions.get(allFunctions.size() - 1).outputVariables();
val firstBackward = initialOuts[0];
// seed backprop with a scalar 1.0 gradient for that output
SDVariable initialGrad = sameDiff.var("one-var", Nd4j.scalar(1.0));
sameDiff.forwardVarForGrad.put(firstBackward.getVarName(), initialGrad);
sameDiff.gradients.put(firstBackward.getVarName(), initialGrad);
// NOTE(review): the returned variable is unused; the call is presumably kept for its side effect
// of registering the marker op that exec() uses to switch to backward mode - confirm.
SDVariable gradientBackwardsMarker = sameDiff.gradientBackwardsMarker(firstBackward);
// re-read the function list so it includes ops added above, then walk it in reverse
allFunctions = new ArrayList<DifferentialFunction>(sameDiff.functionInstancesById.values());
Collections.reverse(allFunctions);
for (DifferentialFunction action : allFunctions) {
if (action instanceof GradientBackwardsMarker) {
// NOTE(review): this branch only skips the marker op; the warning text is misleading
log.warn("Action op state is null for " + action.opName());
continue;
}
DifferentialFunction currFunction = action;
Preconditions.checkState(currFunction.getSameDiff() == sameDiff, "Wrong samediff instance found!");
// Preconditions.checkNotNull("Gradient for " + currFunction.opName() + " was null ! " + sameDiff.getVariableForVertexId(currFunction.getVertexId()).getGradient());
// despite the name, these are the function's OUTPUT variables
val args = currFunction.outputVariables();
for (val arg : args) {
if (arg.getSameDiff() != sameDiff) {
arg.setSameDiff(sameDiff);
}
}
// every output must already have a gradient (seeded above or produced by a later op's diff)
List<SDVariable> grads = new ArrayList<>();
for (val varToGrad : args) {
val grad = varToGrad.gradient();
if (grad == null)
throw new ND4JIllegalStateException("No gradient found for " + varToGrad.getVarName());
grads.add(grad);
}
// NOTE(review): result unused; presumably diff() registers the new gradient ops on sameDiff
// as a side effect - confirm.
List<SDVariable> currFnGrads = currFunction.diff(grads);
}
if (sameDiff.isDebugMode()) {
// ensure all gradients are present for all variables
// NOTE(review): inside this anonymous class, variables() presumably resolves to the enclosing
// SameDiff.this (the forward graph), not the child sameDiff; sameDiff.variables() may have
// been intended - confirm.
for (SDVariable sdVariable : variables()) {
sdVariable.gradient();
}
}
// the declared output of "grad" is a 1x1 variable named "grad"
return new SDVariable[] { sameDiff.var("grad", new int[] { 1, 1 }) };
}
});
Pair<Map<SDVariable, DifferentialFunction>, List<DifferentialFunction>> forward = exec("grad");
SameDiff grad = getFunction("grad");
if (grad.isDebugMode()) {
// ensure all gradients are present for all variables
for (SDVariable sdVariable : grad.variables()) {
sdVariable.gradient();
}
}
return forward;
}
Usage of org.nd4j.autodiff.functions.DifferentialFunction in project nd4j by deeplearning4j.
Class SameDiff, method exec.
/**
* Creates and executes a list of operations.
* <p>
* A fresh {@code FlowPath} is installed for the current thread (replacing any previous one), and
* the ops in {@code functionInstancesById} are executed in its iteration order. TF-style control
* flow (Enter/Exit/NextIteration/LoopCond/Merge/Switch) is replicated by tracking node activity,
* executed flags, frames and loop rewinds on that {@code FlowPath}. SDVariable entries and
* GradientBackwardsMarker ops are skipped; once a marker has been seen, If/While ops run their
* backward path instead of their forward path.
* <p>
* If variables have not been resolved yet, they are first resolved with an empty map.
*
* @return a pair whose left side is {@code opMap} (NOTE(review): never populated in this method,
* so it is always empty) and whose right side lists executed If, forward While, CustomOp and Op
* functions in execution order; Enter, Exit, Merge, Switch, LoopCond, NextIteration and backward
* While ops are not added to it
*/
public Pair<Map<SDVariable, DifferentialFunction>, List<DifferentialFunction>> exec() {
if (!resolvedVariables)
resolveVariablesWith(new LinkedHashMap<String, INDArray>());
List<DifferentialFunction> ops = new ArrayList<>();
// we don't care if this thread had any other FlowPath objects attached. we'll just create new one
localFlowPath.set(new FlowPath());
val flowPath = localFlowPath.get();
Map<SDVariable, DifferentialFunction> opMap = new HashMap<>();
val funcs = new ArrayList<DifferentialFunction>(functionInstancesById.values());
// becomes true once a GradientBackwardsMarker op is reached; switches If/While to backward mode
boolean onBackward = false;
// deque of currently open frame names, innermost last (frames may be nested)
val frames = new ArrayDeque<String>();
// simple flag, set true if within frame
boolean inFrame = false;
// set after an Exit; the innermost frame is removed once the next non-Exit, non-Merge op is reached
boolean frameLeft = false;
// loop index is declared outside the loop because Exit nodes may rewind it
int i = 0;
int exec_counter = 0;
for (; i < funcs.size(); i++) {
++exec_counter;
val opName = funcs.get(i).opName();
// NOTE(review): a new GradientBackwardsMarker is instantiated (twice) per iteration just to read its op name
if (!onBackward && opName.equals(new GradientBackwardsMarker().opName())) {
onBackward = true;
}
if (opName.equals(new GradientBackwardsMarker().opName()))
continue;
DifferentialFunction differentialFunction = funcs.get(i);
val ownName = differentialFunction.getOwnName();
// just registering function for this pass
flowPath.ensureNodeStateExists(differentialFunction.getOwnName());
if (differentialFunction instanceof SDVariable) {
continue;
}
val args = getInputsForFunction(differentialFunction);
log.debug("Step: {}; Executing op {} for node [{}]", exec_counter, opName, ownName);
// check if inputs are active nodes. skip step otherwise
// please note: Exit node can't be skipped, because it's either rewind point or exit loop point
boolean shouldSkip = false;
if (differentialFunction instanceof Merge) {
// a Merge is skipped only when BOTH of its (assumed two) inputs are inactive
val arg0 = args[0];
val arg1 = args[1];
if (!flowPath.isActive(arg0) && !flowPath.isActive(arg1))
shouldSkip = true;
} else {
if (!(differentialFunction instanceof Exit)) {
// if we've left Exit nodes, we can finally delete last frame name
if (frameLeft) {
frameLeft = false;
val frame_name = frames.removeLast();
flowPath.activateFrame(frame_name, false);
flowPath.forgetFrame(frame_name);
}
// we must check, if there's inactive nodes used as inputs for this node
for (val input : args) {
if (!flowPath.isActive(input)) {
// propagate inactivity
flowPath.markActive(differentialFunction.getOwnName(), false);
shouldSkip = true;
break;
}
}
}
}
if (shouldSkip)
continue;
differentialFunction.resolvePropertiesFromSameDiffBeforeExecution();
flowPath.markActive(differentialFunction.getOwnName(), true);
/**
* These operations (Enter/Exit/NextIteration/LoopCond/Merge/Switch) are special snowflakes: they modify graph execution order, and are used here to replicate TF logic.
* SameDiff itself has its own logic for loops and conditionals using Scopes (see the If/While branches below).
*/
if (differentialFunction instanceof LoopCond) {
// this node just passes single input forward, for future evaluation
val inputs = getInputVariablesForFunction(differentialFunction);
val array = inputs[0].getArr();
variableNameToArr.put(differentialFunction.getOwnName(), array.dup(array.ordering()));
flowPath.markExecuted(differentialFunction.getOwnName(), true);
// NOTE: frames.getLast() throws NoSuchElementException if no frame has been entered
if ((int) array.getDouble(0) == 1) {
val frameName = frames.getLast();
// incrementing number of cycles for THIS frame, only if LoopCond is true
flowPath.incrementNumberOfCycles(frameName);
}
} else if (differentialFunction instanceof Enter) {
// if (flowPath.wasExecuted(differentialFunction.getOwnName()))
// continue;
val inputs = getInputVariablesForFunction(differentialFunction);
val array = inputs[0].getArr();
variableNameToArr.put(differentialFunction.getOwnName(), array.dup(array.ordering()));
flowPath.markExecuted(differentialFunction.getOwnName(), true);
// frame_name MUST be non-null here
val frame_name = ((Enter) differentialFunction).getFrameName();
// a frame is registered and pushed only the first time one of its Enter nodes runs
if (!flowPath.isRegisteredFrame(frame_name)) {
flowPath.registerFrame(frame_name);
frames.addLast(frame_name);
inFrame = true;
}
} else if (differentialFunction instanceof Exit) {
// this is just exit point of graph: it maps own input to own output or rewinds graph to specific position planned at first NextIteration node
val frame_name = frames.getLast();
// saving frame_name for backward pass
((Exit) differentialFunction).setFrameName(frame_name);
if (!flowPath.isFrameActive(frame_name)) {
flowPath.markActive(differentialFunction.getOwnName(), false);
// if frame is inactive, lets remove it from queue as well
frameLeft = true;
continue;
}
// and if it's TRUE - we're setting applying rewind by setting loop idx and calling continue
if (flowPath.isRewindPlanned(frame_name)) {
// just reset loop
flowPath.planRewind(frame_name, false);
// NOTE(review): currentPosition and startPosition are unused
val currentPosition = i;
// the loop's i++ runs after `continue`, so execution resumes at rewind position + 1
// (Merge stores its own index - 1 below, so the Merge itself re-executes)
i = flowPath.getRewindPosition(frame_name);
val startPosition = i + 1;
flowPath.setRewindPosition(frame_name, -1);
continue;
}
val inputs = getInputVariablesForFunction(differentialFunction);
val array = inputs[0].getArr();
variableNameToArr.put(differentialFunction.getOwnName(), array.dup(array.ordering()));
flowPath.markExecuted(differentialFunction.getOwnName(), true);
// now it's safe to remove LastFrame
frameLeft = true;
} else if (differentialFunction instanceof NextIteration) {
// this operations merges own input, and schedules rewind to specific Merge node
val inputs = getInputVariablesForFunction(differentialFunction);
val frame_name = frames.getLast();
val array = inputs[0].getArr();
variableNameToArr.put(differentialFunction.getOwnName(), array.dup(array.ordering()));
flowPath.markExecuted(differentialFunction.getOwnName(), true);
// if NextIteration wasn't skipped with inactive branch, we'll plan rewind for this frame. obviously, only once
if (!flowPath.isRewindPlanned(frame_name)) {
flowPath.planRewind(frame_name, true);
continue;
}
} else if (differentialFunction instanceof Merge) {
// merge operation takes two inputs, and saves one of them as own output.
// if SDVariable exists for second input - we use it. First input used otherwise
val inputs = getInputVariablesForFunction(differentialFunction);
val frame_name = frames.size() > 0 ? frames.getLast() : null;
if (frame_name != null)
flowPath.activateFrame(frame_name, true);
// frame_name can be null if this merge node is used for something that's not loop. i.e. switch/merge pair
if (frame_name != null)
flowPath.setRewindPositionOnce(frame_name, i - 1);
// NextIteration can have NO frame_name defined. so let's propagate it
if (inputs.length == 2) {
val secondArg = functionInstancesById.get(inputs[1].getVarName());
if (secondArg != null && secondArg instanceof NextIteration) {
((NextIteration) secondArg).setFrameName(frame_name);
}
}
// we must check second input first here
if (flowPath.wasExecuted(inputs[1].getVarName())) {
// propagate second input
val array = inputs[1].getArr();
variableNameToArr.put(differentialFunction.getOwnName(), array.dup(array.ordering()));
// nullify executed mark
flowPath.markExecuted(inputs[1].getVarName(), false);
} else {
// propagate first input
val array = inputs[0].getArr();
variableNameToArr.put(differentialFunction.getOwnName(), array.dup(array.ordering()));
}
flowPath.markExecuted(differentialFunction.getOwnName(), true);
} else if (differentialFunction instanceof Switch) {
// switch takes 2 inputs: actual input and boolean scalar. If scalar is false, input is saved as output:0, if scalar is true, input is saved as output:1
((CustomOp) differentialFunction).populateInputsAndOutputsFromSameDiff();
val inputs = getInputVariablesForFunction(differentialFunction);
val input = inputs[0].getArr();
val bool = inputs[1].getArr();
// basically we're setting one of the graph branches inactive. branch 0 for false, branch 1 for true
if ((int) bool.getDouble(0) == 0) {
// false step, we'll propagate output:0 here
flowPath.setActiveBranch(differentialFunction.getOwnName(), 0);
flowPath.markActive(differentialFunction.getOwnName(), true);
flowPath.markActive(differentialFunction.getOwnName() + ":1", false);
variableNameToArr.put(differentialFunction.getOwnName(), input.dup(input.ordering()));
} else {
// true step, we'll propagate output:1 here
flowPath.setActiveBranch(differentialFunction.getOwnName(), 1);
variableNameToArr.put(differentialFunction.getOwnName() + ":1", input.dup(input.ordering()));
flowPath.markActive(differentialFunction.getOwnName(), false);
flowPath.markActive(differentialFunction.getOwnName() + ":1", true);
}
flowPath.markExecuted(differentialFunction.getOwnName(), true);
} else if (differentialFunction instanceof If) {
If ifOp = (If) differentialFunction;
if (!onBackward) {
// forward: evaluate the predicate, then run the true or false body and remember which ran
ifOp.getPredicateExecution().exec();
// and possible later processing.
if (ifOp.getTargetBoolean().getArr().sumNumber().doubleValue() > 0) {
ifOp.getLoopBodyExecution().exec();
ifOp.exectedTrueOrFalse(true);
} else {
ifOp.getFalseBodyExecution().exec();
ifOp.exectedTrueOrFalse(false);
}
} else {
// backward: differentiate whichever body ran on the forward pass
if (ifOp.getTrueBodyExecuted() != null) {
Pair<Map<SDVariable, DifferentialFunction>, List<DifferentialFunction>> execBackwards = null;
List<SDVariable> variablesForFunctions = null;
if (ifOp.getTrueBodyExecuted()) {
execBackwards = ifOp.getLoopBodyExecution().execBackwards();
variablesForFunctions = ifOp.getLoopBodyExecution().getVariablesAssociatedWithFunctions(execBackwards.getRight());
} else {
execBackwards = ifOp.getFalseBodyExecution().execBackwards();
variablesForFunctions = ifOp.getFalseBodyExecution().getVariablesAssociatedWithFunctions(execBackwards.getRight());
}
/**
* Maps the variables from the child namespace body to
* the parent. This allows access to the underlying ndarray
* and returning a valid variable reference for autodiff.
*/
for (SDVariable variable : variablesForFunctions) {
// NOTE(review): proxyVar is unused; var(variable) is called for its side effect on this graph
SDVariable proxyVar = var(variable);
}
} else
throw new ND4JIllegalStateException("No body was run.");
}
flowPath.markExecuted(differentialFunction.getOwnName(), true);
ops.add(differentialFunction);
} else if (differentialFunction instanceof While) {
While whileOp = (While) differentialFunction;
if (!onBackward) {
SameDiff execBody = whileOp.getLoopBodyExecution();
// depending on the block add the proper graph body to this for persistence
// and possible later processing.
// note that we need to update the graph predicate by running the execution
whileOp.getPredicateExecution().exec();
while (whileOp.getTargetBoolean().getArr().sumNumber().doubleValue() > 0) {
// run the body
execBody.exec();
// update the predicate
whileOp.getPredicateExecution().exec();
whileOp.incrementLoopCounter();
}
// the While's outputs are the outputs of the last function in the body's iteration order
List<SDVariable> outputs = new ArrayList<>();
val outputFuncArgs = new ArrayList<>(execBody.functionInstancesById.values()).get(execBody.functionInstancesById.values().size() - 1).outputVariables();
outputs.addAll(Arrays.asList(outputFuncArgs));
whileOp.setOutputVars(outputs.toArray(new SDVariable[outputs.size()]));
ops.add(differentialFunction);
} else {
/**
* Note: Need to accumulate gradients.
* Multiply each value by the number of times looped.
* This approximates accumulating the gradient
* across a number of loop cycles.
* We only compute the gradient for the internal loop once
* and from that we multiply the gradient by whileOp.getNumLooped().
*/
// NOTE(review): unlike the forward path, the backward While is not added to ops
Pair<Map<SDVariable, DifferentialFunction>, List<DifferentialFunction>> mapListPair = whileOp.getLoopBodyExecution().execBackwards();
for (SDVariable variable : mapListPair.getFirst().keySet()) {
variable.getArr().muli(whileOp.getNumLooped());
}
}
flowPath.markExecuted(differentialFunction.getOwnName(), true);
} else if (differentialFunction instanceof CustomOp) {
// NOTE(review): throws ClassCastException for a CustomOp that is not a DynamicCustomOp
DynamicCustomOp customOp = (DynamicCustomOp) differentialFunction;
customOp.populateInputsAndOutputsFromSameDiff();
customOp.assertValidForExecution();
customOp.updateInputsFromSameDiff();
Nd4j.getExecutioner().exec(customOp);
/*
if (customOp instanceof LessThanOrEqual) {
log.info("Step: {}; InnerCondition: {} <= {} = {}", exec_counter, customOp.getInputArgument(0), customOp.getInputArgument(1), customOp.getOutputArgument(0));
} else if (customOp instanceof LessThan) {
log.info("Step: {}; OuterCondition: {} <= {} = {}", exec_counter, customOp.getInputArgument(0), customOp.getInputArgument(1), customOp.getOutputArgument(0));
}
*/
flowPath.markExecuted(differentialFunction.getOwnName(), true);
ops.add(customOp);
} else if (differentialFunction instanceof Op) {
val inputs = getInputVariablesForFunction(differentialFunction);
Op op = (Op) differentialFunction;
// ops in differential function might have stale NDArrays used. we should renew them
op.setX(inputs[0].getArr());
if (inputs.length == 2)
op.setY(inputs[1].getArr());
// no dimensions: plain exec; "special" ops execute themselves; otherwise dispatch along axes by op type
if (differentialFunction.getDimensions() == null)
Nd4j.getExecutioner().exec(op);
else if (op.isExecSpecial()) {
op.exec();
} else {
int[] axes = differentialFunction.getDimensions();
if (differentialFunction instanceof Accumulation) {
Accumulation accumulation = (Accumulation) differentialFunction;
Nd4j.getExecutioner().exec(accumulation, axes);
// if the output variable has no array yet, adopt the accumulation result and its shape
if (differentialFunction.outputVariables()[0].getArr() == null) {
val var = differentialFunction.outputVariables()[0];
updateArrayForVarName(var.getVarName(), accumulation.z());
updateShapeForVarName(var.getVarName(), accumulation.z().shape());
}
} else if (differentialFunction instanceof BroadcastOp) {
BroadcastOp broadcastOp = (BroadcastOp) differentialFunction;
Nd4j.getExecutioner().exec(broadcastOp, axes);
} else if (differentialFunction instanceof GradientOp) {
Nd4j.getExecutioner().exec(op);
} else if (differentialFunction instanceof IndexAccumulation) {
IndexAccumulation indexAccumulation = (IndexAccumulation) differentialFunction;
Nd4j.getExecutioner().exec(indexAccumulation, axes);
} else if (differentialFunction instanceof TransformOp) {
TransformOp t = (TransformOp) differentialFunction;
Nd4j.getExecutioner().exec(t, axes);
}
}
flowPath.markExecuted(differentialFunction.getOwnName(), true);
ops.add(differentialFunction);
}
// debug
// printFunction(differentialFunction);
}
return new Pair<>(opMap, ops);
}
Usage of org.nd4j.autodiff.functions.DifferentialFunction in project nd4j by deeplearning4j.
Class SameDiff, method addArgsFor.
/**
 * Registers {@code variables} as the incoming (input) arguments of {@code function}.
 * <p>
 * The function's own name is added to {@code placeHolderFunctions} once for every input name
 * that {@code isPlaceHolder} reports as a placeholder. The forward ({@code incomingArgs}) and
 * reverse ({@code incomingArgsReverse}) argument mappings are then stored, and the function is
 * appended to the {@code functionsArgsFor} list of every input name, creating the list on first use.
 *
 * @param variables names of the variables consumed by {@code function}
 * @param function  the consuming function; its own name must be non-null
 * @throws ND4JIllegalStateException if {@code function.getOwnName()} is null
 */
public void addArgsFor(String[] variables, DifferentialFunction function) {
    final String ownName = function.getOwnName();
    if (ownName == null) {
        throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
    }

    // remember functions that read placeholder inputs (one add per placeholder input)
    for (int idx = 0; idx < variables.length; idx++) {
        if (isPlaceHolder(variables[idx])) {
            placeHolderFunctions.add(ownName);
        }
    }

    incomingArgs.put(variables, function);
    incomingArgsReverse.put(ownName, variables);

    // index: variable name -> functions that take it as an argument
    for (int idx = 0; idx < variables.length; idx++) {
        final String inputName = variables[idx];
        List<DifferentialFunction> consumers = functionsArgsFor.get(inputName);
        if (consumers == null) {
            consumers = new ArrayList<>();
            functionsArgsFor.put(inputName, consumers);
        }
        consumers.add(function);
    }
}
Usage of org.nd4j.autodiff.functions.DifferentialFunction in project nd4j by deeplearning4j.
Class SameDiff, method eval.
/**
 * Evaluates the current graph, binding the given input arrays first.
 * <p>
 * The graph is copied with {@code dup()} and the copy is executed (presumably so this instance
 * is left untouched). When {@code inputs} is non-null and non-empty, its arrays are bound on the
 * copy through {@code resolveVariablesWith(inputs)} before execution. Previously {@code inputs}
 * was silently ignored. A null or empty map keeps the old behaviour: the copy is executed as-is.
 *
 * @param inputs arrays to bind, keyed by variable name; may be null or empty
 * @return for each executed op, in execution order, the array of its first output variable
 * @throws IllegalStateException if no ops were executed
 */
public INDArray[] eval(Map<String, INDArray> inputs) {
    SameDiff execPipeline = dup();
    if (inputs != null && !inputs.isEmpty()) {
        // NOTE(review): resolveVariablesWith may restrict which names it accepts
        // (e.g. placeholders only) - confirm against its implementation.
        execPipeline.resolveVariablesWith(inputs);
    }

    List<DifferentialFunction> executedOps = execPipeline.exec().getRight();
    if (executedOps.isEmpty())
        throw new IllegalStateException("No ops found to execute.");

    INDArray[] ret = new INDArray[executedOps.size()];
    for (int i = 0; i < ret.length; i++) {
        val varName = executedOps.get(i).outputVariables()[0].getVarName();
        ret[i] = execPipeline.getArrForVarName(varName);
    }
    return ret;
}
Usage of org.nd4j.autodiff.functions.DifferentialFunction in project nd4j by deeplearning4j.
Class SameDiff, method addOutgoingFor.
/**
 * Registers {@code varNames} as the outgoing (output) variables of {@code function}.
 * <p>
 * All arguments are validated before anything is recorded. The reverse
 * ({@code outgoingArgsReverse}) and forward ({@code outgoingArgs}) mappings are then stored.
 * Finally, the function is appended to the {@code functionOutputFor} list of every output name,
 * creating the list on first use.
 *
 * @param varNames names of the variables produced by {@code function}; the array and its elements must be non-null
 * @param function the producing function; its own name must be non-null
 * @throws ND4JIllegalStateException if the function has no own name, if outputs were already
 *                                   declared for it, or if {@code varNames} or any element is null
 */
public void addOutgoingFor(String[] varNames, DifferentialFunction function) {
    final String ownName = function.getOwnName();
    if (ownName == null) {
        throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
    }
    if (outgoingArgsReverse.containsKey(ownName)) {
        throw new ND4JIllegalStateException("Outgoing arguments already declared for " + function);
    }
    if (varNames == null) {
        throw new ND4JIllegalStateException("Var names can not be null!");
    }
    for (String outName : varNames) {
        if (outName == null) {
            throw new ND4JIllegalStateException("Variable name elements can not be null!");
        }
    }

    outgoingArgsReverse.put(ownName, varNames);
    outgoingArgs.put(varNames, function);

    // index: variable name -> functions that produce it
    for (String outName : varNames) {
        List<DifferentialFunction> producers = functionOutputFor.get(outName);
        if (producers == null) {
            producers = new ArrayList<>();
            functionOutputFor.put(outName, producers);
        }
        producers.add(function);
    }
}
Aggregations