Example usage of org.nd4j.weightinit.impl.ZeroInitScheme in the nd4j project (deeplearning4j):
the SameDiff class, method generateOutputVariableForOp.
/**
 * Generate (or look up) the output {@link SDVariable}s for the given op and
 * return them in output order.
 * <p>
 * Variable names are derived from {@code baseName}: the first output uses the
 * base name itself and subsequent outputs are suffixed (e.g. {@code split:1},
 * {@code split:2}). If {@code baseName} is null or empty, the function's
 * registered base name is used, falling back to {@link DifferentialFunction#opName()}.
 *
 * @param function the function to generate output variables for
 * @param baseName the base name to derive output variable names from; may be
 *                 null or empty, in which case a fallback name is chosen
 * @return one variable per output of the function
 * @throws ND4JIllegalStateException if the number of outputs of a custom op
 *                                   cannot be determined
 */
public SDVariable[] generateOutputVariableForOp(DifferentialFunction function, String baseName) {
    // Resolve a usable base name. The original condition relied on && binding
    // tighter than ||, which let an empty (non-null) baseName with no registered
    // base name slip through and produce variables with empty names; the second
    // check now also covers the empty-string case.
    if ((baseName == null || baseName.isEmpty()) && getBaseNameForFunction(function) != null)
        baseName = getBaseNameForFunction(function);
    if (baseName == null || baseName.isEmpty())
        baseName = function.opName();
    val outputShape = function.calculateOutputShape();
    if (outputShape == null || outputShape.isEmpty()) {
        if (function instanceof CustomOp) {
            CustomOp customOp = (CustomOp) function;
            val descriptor = customOp.getDescriptor();
            // can't guess number of outputs, variable
            if (descriptor == null || descriptor.getNumOutputs() <= 0) {
                throw new ND4JIllegalStateException("No output variables found!");
            } else {
                char ordering = 'c';
                if (function.args()[0].getArr() != null) {
                    ordering = function.args()[0].getArr().ordering();
                }
                SDVariable[] ret = new SDVariable[descriptor.getNumOutputs()];
                // dynamic shapes: shape is unknown here, so variables are
                // created shapeless with a zero-init scheme
                for (int i = 0; i < ret.length; i++) {
                    SDVariable checkGet = getVariable(baseName);
                    if (checkGet == null) {
                        checkGet = var(generateNewVarName(baseName, i), null, new ZeroInitScheme(ordering));
                    } else if (i > 0 && !importedVarName.contains(baseName)) {
                        // base name is taken and this is not an imported var:
                        // look up the indexed name instead
                        String newName = generateNewVarName(baseName, i);
                        checkGet = getVariable(newName);
                    }
                    if (checkGet == null) {
                        String newName = generateNewVarName(baseName, i);
                        checkGet = var(newName, null, new ZeroInitScheme(ordering));
                    }
                    ret[i] = checkGet;
                }
                return ret;
            }
        } else if (function instanceof BaseOp) {
            // Unresolved shape for a legacy (xyz) op: these always have exactly
            // one output. The enclosing guard already established that the shape
            // list is null or empty, so no further shape check is needed (the
            // original re-called outputShape.isEmpty() here, which NPE'd when
            // outputShape was null).
            SDVariable[] ret = new SDVariable[1];
            SDVariable checkGet = getVariable(baseName);
            char ordering = 'c';
            if (function.args()[0].getArr() != null) {
                ordering = function.args()[0].getArr().ordering();
            }
            if (checkGet == null) {
                checkGet = var(baseName, null, new ZeroInitScheme(ordering));
            } else if (!importedVarName.contains(baseName)) {
                // name is taken and not imported: find a fresh indexed name
                String newName = generateNewVarName(baseName, 0);
                checkGet = var(newName, null, new ZeroInitScheme(ordering));
            }
            if (checkGet == null) {
                checkGet = var(baseName, null, new ZeroInitScheme(ordering));
            }
            ret[0] = checkGet;
            return ret;
        }
        // NOTE(review): if the function is neither a CustomOp nor a BaseOp we
        // fall through with a null/empty outputShape and the code below will
        // NPE or return an empty array — confirm whether that case can occur.
    }
    char ordering = 'c';
    if (function.args()[0].getArr() != null) {
        ordering = function.args()[0].getArr().ordering();
    }
    SDVariable[] ret = new SDVariable[outputShape.size()];
    // ownName/baseName will be used to get variables names
    val ownName = function.getOwnName();
    val rootName = baseName;
    for (int i = 0; i < ret.length; i++) {
        val shape = outputShape.get(i);
        // it should be: rootName:index. i.e.: split:1, split:2, split:3, split:4 etc
        baseName = rootName + (i > 0 ? ":" + i : "");
        SDVariable checkGet = getVariable(baseName);
        if (checkGet == null) {
            // obviously - there's no such var, just add it
            checkGet = var(baseName, shape, new ZeroInitScheme(ordering));
        } else if (shape != null && !shapeAlreadyExistsForVarName(checkGet.getVarName())) {
            // var exists, let's update its shape
            putShapeForVarName(checkGet.getVarName(), shape);
        } else if (shape != null && shapeAlreadyExistsForVarName(checkGet.getVarName())) {
            // no-op.
            // TODO: maybe we should check shapes equality here?
            // it's either var that already exist, or something bad happening
        } else if (!importedVarName.contains(baseName)) {
            // FIXME: dead end. it's impossible to get here with null as shape
            // need to find a new name
            int count = 1;
            String name = baseName + "_" + count + (i > 0 ? ":" + i : "");
            while (getVariable(name) != null) {
                count++;
                name = baseName + "_" + count + (i > 0 ? ":" + i : "");
            }
            // (the loop above only exits when the name is free, so no further
            // collision check is needed here — the original re-checked and
            // threw, which was unreachable)
            checkGet = var(name, shape, new ZeroInitScheme(ordering));
        }
        if (checkGet == null) {
            checkGet = var(baseName + (i > 0 ? ":" + i : ""), shape, new ZeroInitScheme(ordering));
        }
        ret[i] = checkGet;
    }
    return ret;
}
Example usage of org.nd4j.weightinit.impl.ZeroInitScheme in the nd4j project (deeplearning4j):
the BaseGraphMapper class, method importGraph.
/**
 * Converts the given TensorFlow graph into a {@link SameDiff} instance.
 * <p>
 * Every variable in the graph is mapped to an {@link SDVariable}: variables
 * with an unknown data type or an unresolved shape are created shapeless with
 * a zero-init scheme (and registered as placeholders where applicable), while
 * variables with a concrete backing array are imported with their data.
 * Finally, every non-ignored node of the graph is mapped into the result.
 *
 * @param tfGraph the source graph to import
 * @return the populated {@link SameDiff} instance
 */
@Override
public SameDiff importGraph(GRAPH_TYPE tfGraph) {
    SameDiff result = SameDiff.create();
    ImportState<GRAPH_TYPE, TENSOR_TYPE> state = new ImportState<>();
    state.setSameDiff(result);
    state.setGraph(tfGraph);

    val graphVariables = variablesForGraph(tfGraph);
    state.setVariables(graphVariables);

    // First pass: create one SDVariable per graph variable.
    for (Map.Entry<String, TENSOR_TYPE> varEntry : graphVariables.entrySet()) {
        val varName = varEntry.getKey();
        val tensor = varEntry.getValue();

        if (dataTypeForTensor(tensor) == DataBuffer.Type.UNKNOWN) {
            // Unknown dtype: create a shapeless, zero-initialized variable.
            val sdVar = state.getSameDiff().var(varName, null, new ZeroInitScheme('c'));
            // mark as place holder for validating resolution later.
            if (isPlaceHolder(tensor)) {
                state.getSameDiff().addAsPlaceHolder(sdVar.getVarName());
                if (sdVar.getShape() != null)
                    state.getSameDiff().setOriginalPlaceHolderShape(sdVar.getVarName(), sdVar.getShape());
            }
            continue;
        }

        val arr = getNDArrayFromTensor(varName, tensor, tfGraph);
        if (arr != null) {
            // Concrete data available: import it directly.
            val sdVar = state.getSameDiff().var(varName, arr);
            // ensure the array is made available for later processing
            result.associateArrayWithVariable(arr, sdVar);
        } else if (getShapeFromTensor(tensor) == null) {
            // No data and no shape: shapeless, zero-initialized variable.
            val sdVar = state.getSameDiff().var(varName, null, new ZeroInitScheme('c'));
            if (isPlaceHolder(tensor)) {
                val originalShape = getShapeFromTensor(tensor);
                state.getSameDiff().addAsPlaceHolder(sdVar.getVarName());
                if (sdVar.getShape() != null)
                    state.getSameDiff().setOriginalPlaceHolderShape(sdVar.getVarName(), originalShape);
            }
        } else {
            // Shape known but no data yet.
            val originalShape = getShapeFromTensor(tensor);
            val sdVar = state.getSameDiff().var(varName, originalShape);
            if (isPlaceHolder(tensor)) {
                state.getSameDiff().addAsPlaceHolder(sdVar.getVarName());
                state.getSameDiff().setOriginalPlaceHolderShape(sdVar.getVarName(), originalShape);
            }
        }
    }

    // Second pass: map every node that is not on the ignore list
    // (or is an explicit exception to it).
    val tfNodesList = getNodeList(tfGraph);
    for (NODE_TYPE tfNode : tfNodesList) {
        if (!opsToIgnore().contains(getOpType(tfNode)) || isOpIgnoreException(tfNode))
            mapNodeType(tfNode, state);
    }
    return result;
}
Aggregations