use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
the class Min method doDiff.
/**
 * Computes the gradients of the pairwise minimum with respect to both arguments.
 *
 * Each argument is compared against the op output via {@code sameDiff.eq(...)} and the
 * comparison result is multiplied by the incoming gradient, so the gradient is routed
 * to whichever argument equals the minimum (presumably an element-wise indicator —
 * confirm against {@code SameDiff.eq}).
 *
 * @param f1 incoming gradients; only the first entry is used
 * @return a two-element list: gradient w.r.t. the left argument, then the right argument
 */
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
    final SDVariable minimum = outputVariables()[0];

    // Build both comparison results before touching the incoming gradient.
    final SDVariable leftIsMin = sameDiff.eq(larg(), minimum);
    final SDVariable rightIsMin = sameDiff.eq(rarg(), minimum);

    final SDVariable leftGrad = leftIsMin.mul(f1.get(0));
    final SDVariable rightGrad = rightIsMin.mul(f1.get(0));
    return Arrays.asList(leftGrad, rightGrad);
}
use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
the class OldDivOp method doDiff.
/**
 * Computes the gradients of {@code x / y} with respect to both arguments.
 *
 * With upstream gradient {@code g}: {@code dL/dx = g / y} and
 * {@code dL/dy = -(g / y) * (x / y)}.
 *
 * @param i_v incoming gradients; only the first entry is used
 * @return a mutable two-element list: gradient w.r.t. x, then gradient w.r.t. y
 */
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // dL/dx = g / y
    final SDVariable dx = f().div(i_v.get(0), rarg());

    // dL/dy = -(g / y) * (x / y), built from the already computed dx
    final SDVariable negatedDx = f().neg(dx);
    final SDVariable quotient = f().div(larg(), rarg());
    final SDVariable dy = f().mul(negatedDx, quotient);

    final List<SDVariable> grads = new ArrayList<>(2);
    grads.add(dx);
    grads.add(dy);
    return grads;
}
use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
the class OldFloorDivOp method doDiff.
/**
 * Computes the gradients this op reports for its two arguments.
 *
 * With upstream gradient {@code g}, left argument {@code x} and right argument {@code y},
 * the result is {@code g / y} for x and {@code -(g / y) * (x / y)} for y.
 *
 * NOTE(review): this is the gradient formula of ordinary division; whether that is the
 * intended derivative for a floor-division op should be confirmed.
 *
 * @param i_v incoming gradients; only the first entry is used
 * @return a mutable two-element list: gradient w.r.t. x, then gradient w.r.t. y
 */
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    final SDVariable wrtLeft = f().div(i_v.get(0), rarg());

    // Same evaluation order as a nested call: negate first, then x / y, then multiply.
    final SDVariable negated = f().neg(wrtLeft);
    final SDVariable leftOverRight = f().div(larg(), rarg());
    final SDVariable wrtRight = f().mul(negated, leftOverRight);

    final List<SDVariable> gradients = new ArrayList<>(2);
    gradients.add(wrtLeft);
    gradients.add(wrtRight);
    return gradients;
}
use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
the class RectifedLinear method doDiff.
/**
 * Computes the gradient of this activation with respect to its single argument.
 *
 * A {@code Step} op is built over the argument using this op's {@code cutoff}, and its
 * first output is multiplied by the incoming gradient.
 *
 * @param i_v incoming gradients; only the first entry is used
 * @return a single-element list holding the gradient w.r.t. the argument
 */
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    final Step stepOp = new Step(sameDiff, arg(), false, cutoff);
    final SDVariable stepOutput = stepOp.outputVariables()[0];
    final SDVariable upstream = i_v.get(0);
    return Arrays.asList(stepOutput.mul(upstream));
}
use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
the class While method doImport.
/**
 * Imports a TensorFlow while-loop into this op by walking the graph's node list
 * sequentially, starting at {@code currIndex}.
 *
 * Two fresh {@link SameDiff} scopes are created and registered on {@code initWith} as
 * sub-functions: one for the loop condition and one for the loop body. The node list is
 * then consumed in consecutive phases, each phase advancing {@code currIndex}:
 * <ol>
 *   <li>{@code Enter} nodes: their inputs become variables in both scopes;</li>
 *   <li>{@code Merge} nodes: registered as variables and skipped;</li>
 *   <li>condition nodes up to and including {@code LoopCond};</li>
 *   <li>{@code Switch} nodes: skipped;</li>
 *   <li>{@code Identity} nodes: mapped to functions with their inputs as arguments;</li>
 *   <li>body nodes up to the first {@code NextIteration} (nested loops recurse);</li>
 *   <li>{@code NextIteration} nodes: collected as this op's inputs;</li>
 *   <li>{@code Exit} nodes: registered on {@code initWith} and skipped.</li>
 * </ol>
 * Finally the first output of the last function in the condition scope is stored as
 * the loop predicate.
 *
 * @param nodeDef           the node that triggered the import; its name is added to {@code skipSet}
 * @param initWith          the parent graph that receives the sub-functions and variables
 * @param attributesForNode attributes of the node, forwarded to nested loop imports
 * @param graph             the TensorFlow graph whose node list is traversed
 * @param skipSet           names of nodes already handled; mutated by this method
 * @param currIndex         shared cursor into the node list; mutated by this method
 * @throws ND4JIllegalArgumentException if the condition scope ends up with no functions
 */
private void doImport(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph, Set<String> skipSet, AtomicInteger currIndex) {
    val uniqueId = java.util.UUID.randomUUID().toString();
    skipSet.add(nodeDef.getName());
    val scopeCondition = SameDiff.create();
    val scopeLoop = SameDiff.create();
    initWith.putSubFunction("condition-" + uniqueId, scopeCondition);
    initWith.putSubFunction("loopbody-" + uniqueId, scopeLoop);
    this.loopBodyExecution = scopeLoop;
    this.predicateExecution = scopeCondition;
    this.startPosition = currIndex;
    // The message has a placeholder, so it needs an argument to be useful.
    log.info("Adding 2 new scopes for WHILE {}", nodeDef.getName());
    val nodes = graph.getNodeList();

    // Phase 1: Enter nodes. Every input is made available in both scopes.
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        if (!tfNode.getOp().equalsIgnoreCase("enter")) {
            break;
        }
        skipSet.add(tfNode.getName());
        val vars = new SDVariable[tfNode.getInputCount()];
        for (int e = 0; e < tfNode.getInputCount(); e++) {
            val input = TFGraphMapper.getInstance().getNodeName(tfNode.getInput(e));
            vars[e] = initWith.getVariable(input) == null ? initWith.var(input, null, new ZeroInitScheme()) : initWith.getVariable(input);
            scopeCondition.var(vars[e]);
            scopeLoop.var(vars[e]);
        }
        // Overwritten once the NextIteration nodes have been collected below.
        this.inputVars = vars;
    }

    // Phase 2: Merge nodes are skipped, since variables were already captured at the Enter step.
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        if (!tfNode.getOp().equalsIgnoreCase("merge")) {
            // The first non-Merge node still gets a placeholder variable in the loop scope.
            scopeLoop.var(TFGraphMapper.getInstance().getNodeName(tfNode.getName()), null, new ZeroInitScheme());
            break;
        }
        skipSet.add(tfNode.getName());
        val mergeVar = scopeLoop.var(TFGraphMapper.getInstance().getNodeName(tfNode.getName()), null, new ZeroInitScheme());
        scopeCondition.var(mergeVar);
        initWith.var(mergeVar);
    }

    // Phase 3: conditional scope, parsed up to and including LoopCond.
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        if (tfNode.getOp().equalsIgnoreCase("LoopCond")) {
            skipSet.add(tfNode.getName());
            // Step past LoopCond manually because break skips the loop increment.
            currIndex.incrementAndGet();
            break;
        }
        boolean isConst = tfNode.getOp().equalsIgnoreCase("const");
        boolean isVar = tfNode.getOp().startsWith("VariableV");
        boolean isPlaceholder = tfNode.getOp().startsWith("Placeholder");
        if (isConst || isVar || isPlaceholder) {
            val condVar = scopeCondition.var(tfNode.getName(), null, new ZeroInitScheme());
            scopeLoop.var(condVar);
            initWith.var(condVar);
            log.info("Adding condition var [{}]", condVar.getVarName());
        } else if (!skipSet.contains(tfNode.getName())) {
            val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName());
            func.initFromTensorFlow(tfNode, scopeCondition, nodeDef.getAttrMap(), graph);
            // NOTE(review): a condition function is attached to the loop-body scope here — confirm intent.
            func.setSameDiff(scopeLoop);
        }
        skipSet.add(tfNode.getName());
    }

    // Phase 4: Switch nodes are skipped.
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        if (!tfNode.getOp().equalsIgnoreCase("Switch")) {
            break;
        }
        skipSet.add(tfNode.getName());
    }

    // Phase 5: Identity nodes become functions whose inputs are registered in both scopes.
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        if (!tfNode.getOp().equalsIgnoreCase("Identity")) {
            break;
        }
        val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName());
        func.initFromTensorFlow(tfNode, initWith, nodeDef.getAttrMap(), graph);
        func.setSameDiff(scopeLoop);
        val variables = new SDVariable[tfNode.getInputCount()];
        for (int i = 0; i < tfNode.getInputCount(); i++) {
            val existing = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i)));
            if (existing == null) {
                // NOTE(review): the variable is created under the raw input string while the
                // lookup above uses the mapped node name — confirm this is intended.
                variables[i] = initWith.var(tfNode.getInput(i), null, new ZeroInitScheme());
            } else {
                variables[i] = existing;
            }
            scopeCondition.var(variables[i]);
            scopeLoop.var(variables[i]);
        }
        scopeLoop.addArgsFor(variables, func);
        skipSet.add(tfNode.getName());
    }

    // Phase 6: body scope, parsed until the first NextIteration node that was not already handled.
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        if (skipSet.contains(tfNode.getName())) {
            log.info("Skipping: {}", tfNode.getName());
            continue;
        }
        if (tfNode.getOp().equalsIgnoreCase("NextIteration")) {
            break;
        }
        boolean isConst = tfNode.getOp().equalsIgnoreCase("const");
        boolean isVar = tfNode.getOp().startsWith("VariableV");
        boolean isPlaceholder = tfNode.getOp().startsWith("Placeholder");
        if (isConst || isVar || isPlaceholder) {
            val bodyVar = scopeLoop.var(tfNode.getName(), null, new ZeroInitScheme());
            log.info("Adding body var [{}]", bodyVar.getVarName());
        } else {
            log.info("starting on [{}]: {}", tfNode.getName(), tfNode.getOp());
            if (tfNode.getOp().equalsIgnoreCase("enter")) {
                // A nested loop: recurse with the shared cursor and skip set.
                log.info("NEW LOOP ----------------------------------------");
                val func = new While(currIndex);
                func.doImport(nodeDef, initWith, attributesForNode, graph, skipSet, currIndex);
                func.setSameDiff(initWith);
                log.info("END LOOP ----------------------------------------");
            } else {
                val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName());
                func.initFromTensorFlow(tfNode, initWith, nodeDef.getAttrMap(), graph);
                // NOTE(review): a body function is attached to the condition scope here — confirm intent.
                func.setSameDiff(scopeCondition);
                val variables = new SDVariable[tfNode.getInputCount()];
                for (int i = 0; i < tfNode.getInputCount(); i++) {
                    val name = TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i));
                    // Resolution order: condition scope, then loop scope, then the parent graph.
                    variables[i] = scopeCondition.getVariable(name);
                    if (variables[i] == null) {
                        val loopVar = scopeLoop.getVariable(name);
                        if (loopVar == null) {
                            // NOTE(review): initWith.getVariable(name) may itself be null here — confirm.
                            variables[i] = scopeCondition.var(initWith.getVariable(name));
                        } else {
                            variables[i] = loopVar;
                        }
                    }
                }
                scopeLoop.addArgsFor(variables, func);
            }
        }
        skipSet.add(tfNode.getName());
    }

    val returnInputs = new ArrayList<SDVariable>();
    // NOTE(review): nothing is ever added to this list, so outputVars is always empty.
    val returnOutputs = new ArrayList<SDVariable>();

    // Phase 7: NextIteration nodes are mapped to this op's inputs.
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        if (!tfNode.getOp().equalsIgnoreCase("NextIteration")) {
            break;
        }
        skipSet.add(tfNode.getName());
        val inputName = TFGraphMapper.getInstance().getNodeName(tfNode.getName());
        val input = initWith.getVariable(inputName) == null ? initWith.var(inputName, null, new ZeroInitScheme()) : initWith.getVariable(inputName);
        returnInputs.add(input);
    }
    this.outputVars = returnOutputs.toArray(new SDVariable[returnOutputs.size()]);
    this.inputVars = returnInputs.toArray(new SDVariable[returnInputs.size()]);
    initWith.addArgsFor(inputVars, this);
    initWith.addOutgoingFor(outputVars, this);

    // Phase 8: Exit nodes. Each one is guaranteed a variable on the parent graph.
    for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
        val tfNode = nodes.get(currIndex.get());
        if (!tfNode.getOp().equalsIgnoreCase("Exit")) {
            break;
        }
        skipSet.add(tfNode.getName());
        val exitName = TFGraphMapper.getInstance().getNodeName(tfNode.getName());
        if (initWith.getVariable(exitName) == null) {
            initWith.var(exitName, null, new ZeroInitScheme());
        }
    }

    // The output of the condition is assumed to be a single scalar: the first output of
    // the last function registered in the condition scope is used as the predicate.
    val conditionVars = scopeCondition.functions();
    if (conditionVars.length < 1) {
        throw new ND4JIllegalArgumentException("No functions found!");
    }
    this.targetBoolean = conditionVars[conditionVars.length - 1].outputVariables()[0];
    log.info("-------------------------------------------");
}
Aggregations