Use of org.nd4j.weightinit.impl.ZeroInitScheme in project nd4j by deeplearning4j.
The class TFGraphMapper, method mapNodeType.
@Override
public void mapNodeType(NodeDef tfNode, ImportState<GraphDef, NodeDef> importState) {
if (shouldSkip(tfNode) || alreadySeen(tfNode) || isVariableNode(tfNode)) {
return;
}
val diff = importState.getSameDiff();
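// NOTE: isVariableNode(tfNode) already triggers the early return above, so this branch is unreachable as written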
if (isVariableNode(tfNode)) {
List<Integer> dimensions = new ArrayList<>();
Map<String, AttrValue> attributes = getAttrMap(tfNode);
if (attributes.containsKey(VALUE_ATTR_KEY)) {
diff.var(getName(tfNode), getArrayFrom(tfNode, importState.getGraph()));
} else if (attributes.containsKey(SHAPE_KEY)) {
AttrValue shape = attributes.get(SHAPE_KEY);
int[] shapeArr = getShapeFromAttr(shape);
int dims = shapeArr.length;
if (dims > 0) {
// even a vector is 2d in nd4j
if (dims == 1)
dimensions.add(1);
for (int e = 0; e < dims; e++) {
// TODO: eventually we want long shapes :(
dimensions.add(shapeArr[e]);
}
}
}
} else if (isPlaceHolder(tfNode)) {
val vertexId = diff.getVariable(getName(tfNode));
diff.addAsPlaceHolder(vertexId.getVarName());
} else {
val opName = tfNode.getOp();
val nodeName = tfNode.getName();
// FIXME: early draft
// conditional import
/*
if (nodeName.startsWith("cond") && nodeName.contains("/")) {
val str = nodeName.replaceAll("/.*$","");
importCondition(str, tfNode, importState);
seenNodes.add(nodeName);
return;
} else if (nodeName.startsWith("while")) {
// while loop import
return;
}
*/
val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(opName);
if (differentialFunction == null) {
throw new ND4JIllegalStateException("No tensorflow op found for " + opName + "; possibly a missing operation class?");
}
try {
val newInstance = differentialFunction.getClass().newInstance();
val args = new SDVariable[tfNode.getInputCount()];
newInstance.setOwnName(tfNode.getName());
for (int i = 0; i < tfNode.getInputCount(); i++) {
val name = getNodeName(tfNode.getInput(i));
args[i] = diff.getVariable(name);
if (args[i] == null) {
args[i] = diff.var(name, null, new ZeroInitScheme('f'));
diff.addAsPlaceHolder(args[i].getVarName());
}
/**
* Associate the output/result variable with its inputs,
* and record that it has a placeholder argument that must
* be resolved before anything can execute.
*/
if (diff.isPlaceHolder(args[i].getVarName())) {
diff.putPlaceHolderForVariable(args[i].getVarName(), name);
}
}
diff.addArgsFor(args, newInstance);
newInstance.setSameDiff(importState.getSameDiff());
newInstance.initFromTensorFlow(tfNode, diff, getAttrMap(tfNode), importState.getGraph());
mapProperties(newInstance, tfNode, importState.getGraph(), importState.getSameDiff(), newInstance.mappingsForFunction());
importState.getSameDiff().putFunctionForId(newInstance.getOwnName(), newInstance);
// ensure we can track node name to function instance later.
diff.setBaseNameForFunctionInstanceId(tfNode.getName(), newInstance);
diff.addVarNameForImport(tfNode.getName());
} catch (Exception e) {
log.error("Failed with [{}]", opName);
throw new RuntimeException(e);
}
}
}
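For context, ZeroInitScheme is used throughout these snippets as a cheap stand-in initializer: it produces a zero-filled array so a variable can be registered before its real content is known (here, for not-yet-imported inputs that are then marked as placeholders). A minimal standalone sketch, assuming this nd4j generation's WeightInitScheme API with a create(int[]) overload:
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.weightinit.impl.ZeroInitScheme;
public class ZeroInitSchemeDemo {
public static void main(String[] args) {
// 'f' requests column-major (Fortran) ordering for the created array
ZeroInitScheme scheme = new ZeroInitScheme('f');
// shape values are illustrative only
INDArray zeros = scheme.create(new int[] { 2, 3 });
System.out.println(zeros); // 2x3 array of zeros
}
}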
Use of org.nd4j.weightinit.impl.ZeroInitScheme in project nd4j by deeplearning4j.
The class SameDiffTests, method testLinearModule.
@Test
public void testLinearModule() {
int nIn = 5;
Linear linear = Linear.execBuilder().nIn(nIn).nOut(4).weightInitScheme(new UniformInitScheme('f', nIn)).biasWeightInitScheme(new ZeroInitScheme('f')).build();
linear.exec(Nd4j.linspace(1, 20, 20).reshape(4, 5));
assertEquals(1, linear.numOutputArguments());
}
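Note the contrast with testLinearModule2 further down: the weights here come from UniformInitScheme, so their values are random and the test only asserts the number of output arguments, while the all-ones weights of OneInitScheme allow exact output values to be asserted.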
Use of org.nd4j.weightinit.impl.ZeroInitScheme in project nd4j by deeplearning4j.
The class While, method doImport.
private void doImport(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph, Set<String> skipSet, AtomicInteger currIndex) {
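// TensorFlow serializes a while loop as a flat run of control-flow nodes:
// Enter* -> Merge* -> (condition ops) -> LoopCond -> Switch* -> Identity* -> (body ops) -> NextIteration* -> Exit*
// each loop below consumes one of these phases from the graph's node list, in order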
val uniqueId = java.util.UUID.randomUUID().toString();
skipSet.add(nodeDef.getName());
val scopeCondition = SameDiff.create();
val scopeLoop = SameDiff.create();
initWith.putSubFunction("condition-" + uniqueId, scopeCondition);
initWith.putSubFunction("loopbody-" + uniqueId, scopeLoop);
this.loopBodyExecution = scopeLoop;
this.predicateExecution = scopeCondition;
this.startPosition = currIndex;
log.info("Adding 2 new scopes for WHILE {}");
val nodes = graph.getNodeList();
for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
val tfNode = nodes.get(currIndex.get());
if (!tfNode.getOp().equalsIgnoreCase("enter")) {
// skipSet.add(tfNode.getName());
break;
}
// if (skipSet.contains(tfNode.getName()))
// continue;
skipSet.add(tfNode.getName());
val vars = new SDVariable[tfNode.getInputCount()];
for (int e = 0; e < tfNode.getInputCount(); e++) {
val input = TFGraphMapper.getInstance().getNodeName(tfNode.getInput(e));
vars[e] = initWith.getVariable(input) == null ? initWith.var(input, null, new ZeroInitScheme()) : initWith.getVariable(input);
scopeCondition.var(vars[e]);
scopeLoop.var(vars[e]);
}
this.inputVars = vars;
}
// now we're skipping the Merge step, since we've already captured the variables at the Enter step
int mergedCnt = 0;
for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
val tfNode = nodes.get(currIndex.get());
if (!tfNode.getOp().equalsIgnoreCase("merge")) {
scopeLoop.var(TFGraphMapper.getInstance().getNodeName(tfNode.getName()), null, new ZeroInitScheme());
break;
}
skipSet.add(tfNode.getName());
val var = scopeLoop.var(TFGraphMapper.getInstance().getNodeName(tfNode.getName()), null, new ZeroInitScheme());
scopeCondition.var(var);
initWith.var(var);
mergedCnt++;
}
// now, we're adding conditional scope
for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
val tfNode = nodes.get(currIndex.get());
// we're parsing the condition scope, up to the LoopCond node
if (tfNode.getOp().equalsIgnoreCase("LoopCond")) {
skipSet.add(tfNode.getName());
currIndex.incrementAndGet();
break;
}
boolean isConst = tfNode.getOp().equalsIgnoreCase("const");
boolean isVar = tfNode.getOp().startsWith("VariableV");
boolean isPlaceholder = tfNode.getOp().startsWith("Placeholder");
if (isConst || isVar || isPlaceholder) {
val var = scopeCondition.var(tfNode.getName(), null, new ZeroInitScheme());
scopeLoop.var(var);
initWith.var(var);
log.info("Adding condition var [{}]", var.getVarName());
} else if (!skipSet.contains(tfNode.getName())) {
val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName());
func.initFromTensorFlow(tfNode, scopeCondition, nodeDef.getAttrMap(), graph);
func.setSameDiff(scopeLoop);
}
skipSet.add(tfNode.getName());
}
// time to skip some Switch calls
int switchCnt = 0;
for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
val tfNode = nodes.get(currIndex.get());
// we're skipping consecutive Switch nodes
if (!tfNode.getOp().equalsIgnoreCase("Switch"))
break;
switchCnt++;
skipSet.add(tfNode.getName());
}
// now we're parsing Identity step
int identityCnt = 0;
for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
val tfNode = nodes.get(currIndex.get());
if (!tfNode.getOp().equalsIgnoreCase("Identity")) {
break;
}
val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName());
func.initFromTensorFlow(tfNode, initWith, nodeDef.getAttrMap(), graph);
func.setSameDiff(scopeLoop);
val variables = new SDVariable[tfNode.getInputCount()];
for (int i = 0; i < tfNode.getInputCount(); i++) {
val testVar = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i)));
if (testVar == null) {
variables[i] = initWith.var(tfNode.getInput(i), null, new ZeroInitScheme());
scopeCondition.var(variables[i]);
scopeLoop.var(variables[i]);
continue;
} else {
variables[i] = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i)));
scopeCondition.var(variables[i]);
scopeLoop.var(variables[i]);
}
}
scopeLoop.addArgsFor(variables, func);
skipSet.add(tfNode.getName());
}
// parsing body scope
for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
val tfNode = nodes.get(currIndex.get());
if (skipSet.contains(tfNode.getName())) {
log.info("Skipping: {}", tfNode.getName());
continue;
}
if (tfNode.getOp().equalsIgnoreCase("NextIteration")) {
// skipSet.add(tfNode.getName());
break;
}
boolean isConst = tfNode.getOp().equalsIgnoreCase("const");
boolean isVar = tfNode.getOp().startsWith("VariableV");
boolean isPlaceholder = tfNode.getOp().startsWith("Placeholder");
if (isConst || isVar || isPlaceholder) {
val var = scopeLoop.var(tfNode.getName(), null, new ZeroInitScheme());
log.info("Adding body var [{}]", var.getVarName());
} else {
log.info("starting on [{}]: {}", tfNode.getName(), tfNode.getOp());
if (tfNode.getOp().equalsIgnoreCase("enter")) {
log.info("NEW LOOP ----------------------------------------");
val func = new While(currIndex);
func.doImport(nodeDef, initWith, attributesForNode, graph, skipSet, currIndex);
func.setSameDiff(initWith);
log.info("END LOOP ----------------------------------------");
} else {
val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName());
func.initFromTensorFlow(tfNode, initWith, nodeDef.getAttrMap(), graph);
func.setSameDiff(scopeCondition);
val variables = new SDVariable[tfNode.getInputCount()];
for (int i = 0; i < tfNode.getInputCount(); i++) {
val name = TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i));
variables[i] = scopeCondition.getVariable(name);
if (variables[i] == null) {
if (scopeLoop.getVariable(name) == null)
variables[i] = scopeCondition.var(initWith.getVariable(name));
else
variables[i] = scopeLoop.getVariable(name);
}
}
scopeLoop.addArgsFor(variables, func);
}
}
skipSet.add(tfNode.getName());
}
val returnInputs = new ArrayList<SDVariable>();
val returnOutputs = new ArrayList<SDVariable>();
// mapping NextIterations, to Return op
for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
val tfNode = nodes.get(currIndex.get());
if (!tfNode.getOp().equalsIgnoreCase("NextIteration"))
break;
skipSet.add(tfNode.getName());
val inputName = TFGraphMapper.getInstance().getNodeName(tfNode.getName());
val input = initWith.getVariable(inputName) == null ? initWith.var(inputName, null, new ZeroInitScheme()) : initWith.getVariable(inputName);
returnInputs.add(input);
}
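// NOTE: returnOutputs is never populated above, so outputVars ends up empty as written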
this.outputVars = returnOutputs.toArray(new SDVariable[returnOutputs.size()]);
this.inputVars = returnInputs.toArray(new SDVariable[returnInputs.size()]);
initWith.addArgsFor(inputVars, this);
initWith.addOutgoingFor(outputVars, this);
// we should also map While/Exit to libnd4j while
int exitCnt = 0;
for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
val tfNode = nodes.get(currIndex.get());
if (!tfNode.getOp().equalsIgnoreCase("Exit")) {
// skipSet.add(tfNode.getName());
break;
}
skipSet.add(tfNode.getName());
val inputName = TFGraphMapper.getInstance().getNodeName(tfNode.getName());
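// 'input' is unused; the ternary is kept for its side effect of registering the Exit output with initWith when missing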
val input = initWith.getVariable(inputName) == null ? initWith.var(inputName, null, new ZeroInitScheme()) : initWith.getVariable(inputName);
}
// the output of the condition should always be a single scalar;
// this is a safe assumption
val conditionVars = scopeCondition.functions();
if (conditionVars.length < 1) {
throw new ND4JIllegalArgumentException("No functions found!");
}
this.targetBoolean = conditionVars[conditionVars.length - 1].outputVariables()[0];
log.info("-------------------------------------------");
}
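After doImport returns, the While instance carries everything needed at execution time: predicateExecution and loopBodyExecution reference the two sub-graphs registered under the generated UUID names, inputVars holds the NextIteration-derived arguments, and targetBoolean is the first output of the last function in the condition scope.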
Use of org.nd4j.weightinit.impl.ZeroInitScheme in project nd4j by deeplearning4j.
The class While, method init.
private void init(String blockName, SameDiff parent, SDVariable[] inputVars, SameDiff.SameDiffConditional predicate, SameDiff.SameDiffFunctionDefinition condition, SameDiff.SameDiffFunctionDefinition trueBody) {
this.sameDiff = parent;
this.inputVars = inputVars;
this.predicate = predicate;
this.trueBody = trueBody;
this.blockName = blockName;
this.dummyResult = parent.var("dummyresult-" + UUID.randomUUID().toString(), new int[] { 1, 1 }, new ZeroInitScheme('f'));
parent.putFunctionForId(getOwnName(), this);
parent.addArgsFor(inputVars, this);
parent.addOutgoingFor(new SDVariable[] { dummyResult }, this);
// create a samediff sub graph for running just the execution
// return a reference to the loop for referencing during actual execution
SameDiff sameDiff = SameDiff.create();
// store the reference to the result array and the same diff execution instance
this.targetBoolean = predicate.eval(sameDiff, condition, inputVars);
this.predicateExecution = sameDiff;
// store references to the loop body
String trueBodyName = "true-body-" + UUID.randomUUID().toString();
this.trueBodyName = trueBodyName;
// running define function will setup a proper same diff instance
parent.defineFunction(trueBodyName, trueBody, inputVars);
parent.defineFunction(blockName, condition, inputVars);
parent.putSubFunction("predicate-eval-body", sameDiff);
// get a reference to the actual loop body
this.loopBodyExecution = parent.getFunction(trueBodyName);
}
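Taken together, init() encodes the usual while-loop contract; a conceptual sketch only, where evaluate and run are hypothetical helpers standing in for the real SameDiff execution machinery:
// hypothetical pseudo-Java: evaluate() and run() are NOT real SameDiff API
while (evaluate(predicateExecution, targetBoolean, inputVars)) {
inputVars = run(loopBodyExecution, inputVars);
}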
Use of org.nd4j.weightinit.impl.ZeroInitScheme in project nd4j by deeplearning4j.
The class SameDiffTests, method testLinearModule2.
@Test
public void testLinearModule2() {
Linear linear = Linear.execBuilder().nIn(3).nOut(2).weightInitScheme(new OneInitScheme('f')).biasWeightInitScheme(new ZeroInitScheme('f')).build();
linear.exec(Nd4j.linspace(1, 6, 6).reshape(2, 3));
INDArray assertion = Nd4j.create(new double[][] { { 6, 6 }, { 15, 15 } });
assertEquals(assertion, linear.outputArguments()[0]);
}
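The asserted values follow directly from the init schemes: OneInitScheme('f') yields an all-ones 3x2 weight matrix and ZeroInitScheme('f') an all-zeros bias, so each output row is the row-sum of the corresponding input row. A quick sketch computing the same forward pass by hand:
INDArray x = Nd4j.linspace(1, 6, 6).reshape(2, 3); // [[1, 2, 3], [4, 5, 6]]
INDArray w = Nd4j.ones(3, 2); // what OneInitScheme('f') produces here
INDArray b = Nd4j.zeros(1, 2); // what ZeroInitScheme('f') produces here
INDArray out = x.mmul(w).addRowVector(b); // [[6, 6], [15, 15]]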