Use of org.nd4j.linalg.exception.ND4JIllegalStateException in the nd4j project by deeplearning4j: the addOutgoingFor method of the SameDiff class.
/**
 * Declares the outgoing (output) arguments of a function in the graph.
 * Also checks the input arguments and updates the graph, adding an
 * appropriate edge when the full graph is declared.
 *
 * @param varNames the names of the variables produced by {@code function};
 *                 must be non-null and contain no null elements
 * @param function the function whose outgoing arguments are being declared;
 *                 must have a non-null instance id and no prior declaration
 * @throws ND4JIllegalStateException if the function's instance id is null,
 *         if outgoing arguments were already declared for it, or if
 *         {@code varNames} is null or contains a null element
 */
public void addOutgoingFor(String[] varNames, DifferentialFunction function) {
    if (function.getOwnName() == null)
        throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
    if (outgoingArgsReverse.containsKey(function.getOwnName())) {
        throw new ND4JIllegalStateException("Outgoing arguments already declared for " + function);
    }
    if (varNames == null)
        throw new ND4JIllegalStateException("Var names can not be null!");
    for (int i = 0; i < varNames.length; i++) {
        if (varNames[i] == null)
            throw new ND4JIllegalStateException("Variable name elements can not be null!");
    }
    outgoingArgsReverse.put(function.getOwnName(), varNames);
    outgoingArgs.put(varNames, function);
    // Register this function as a producer of each of its output variables.
    for (val resultName : varNames) {
        functionOutputFor.computeIfAbsent(resultName, k -> new ArrayList<>()).add(function);
    }
}
Use of org.nd4j.linalg.exception.ND4JIllegalStateException in the nd4j project by deeplearning4j: the getNDArrayFromTensor method of the OnnxGraphMapper class.
/**
 * Materializes an {@link INDArray} from the raw bytes of the named initializer
 * tensor in an ONNX graph.
 *
 * @param tensorName  name of the initializer to look up in {@code graph}
 * @param tensorProto type descriptor used to resolve the data type and shape
 * @param graph       the ONNX graph whose initializers are searched
 * @return the constructed array, or {@code null} if no initializer with the
 *         given name exists in the graph
 * @throws ND4JIllegalStateException if {@code tensorProto} is not initialized
 */
@Override
public INDArray getNDArrayFromTensor(String tensorName, OnnxProto3.TypeProto.Tensor tensorProto, OnnxProto3.GraphProto graph) {
DataBuffer.Type type = dataTypeForTensor(tensorProto);
if (!tensorProto.isInitialized()) {
throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
}
// Linear scan over the graph's initializers for a name match.
OnnxProto3.TensorProto tensor = null;
for (int i = 0; i < graph.getInitializerCount(); i++) {
val initializer = graph.getInitializer(i);
if (initializer.getName().equals(tensorName)) {
tensor = initializer;
break;
}
}
if (tensor == null)
return null;
ByteString bytes = tensor.getRawData();
// The protobuf view is read-only; copy it into a native-ordered direct
// buffer that Nd4j.createBuffer can wrap, then rewind before handing it off.
// NOTE(review): this copies capacity() bytes, which assumes the read-only
// view starts at position 0 and is not a slice — presumably true for
// asReadOnlyByteBuffer(); confirm against the protobuf ByteString contract.
ByteBuffer byteBuffer = bytes.asReadOnlyByteBuffer().order(ByteOrder.nativeOrder());
ByteBuffer directAlloc = ByteBuffer.allocateDirect(byteBuffer.capacity()).order(ByteOrder.nativeOrder());
directAlloc.put(byteBuffer);
directAlloc.rewind();
// Element count comes from the declared shape, not from the byte length.
int[] shape = getShapeFromTensor(tensorProto);
DataBuffer buffer = Nd4j.createBuffer(directAlloc, type, ArrayUtil.prod(shape));
INDArray arr = Nd4j.create(buffer).reshape(shape);
return arr;
}
Use of org.nd4j.linalg.exception.ND4JIllegalStateException in the nd4j project by deeplearning4j: the mapProperty method of the TFGraphMapper class.
/**
 * Maps a single named property from a TensorFlow node onto a field of the
 * corresponding {@link DifferentialFunction}.
 *
 * Resolution order: if the property's mapping names a TF input position, the
 * value is pulled from that input's array (deferring via placeholder / resolve
 * bookkeeping when the array is not yet available); otherwise the value is
 * read from the node's attribute named by the mapping.
 *
 * @param name the property name to map
 * @param on the function instance whose field is set
 * @param node the TF node supplying the value; must not be null
 * @param graph the TF graph the node belongs to
 * @param sameDiff the SameDiff instance used for deferred resolution
 * @param propertyMappingsForFunction op-type -> (property name -> mapping)
 * @throws ND4JIllegalStateException if {@code node} is null, if no fields are
 *         known for the op, or if the mapping declares no property names
 */
@Override
public void mapProperty(String name, DifferentialFunction on, NodeDef node, GraphDef graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) {
if (node == null) {
throw new ND4JIllegalStateException("No node found for name " + name);
}
val mapping = propertyMappingsForFunction.get(getOpType(node)).get(name);
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
// Case 1: the property value lives in one of the node's inputs.
if (mapping.getTfInputPosition() != null && mapping.getTfInputPosition() < node.getInputCount()) {
int tfMappingIdx = mapping.getTfInputPosition();
// Negative positions index from the end of the input list.
if (tfMappingIdx < 0)
tfMappingIdx += node.getInputCount();
val input = node.getInput(tfMappingIdx);
val inputNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, input);
INDArray arr = getArrayFrom(inputNode, graph);
if (arr == null) {
arr = sameDiff.getArrForVarName(input);
}
// Array not yet available but the node exists: defer resolution of this
// property until the variable's array is known.
if (arr == null && inputNode != null) {
sameDiff.addPropertyToResolve(on, name);
sameDiff.addVariableMappingForField(on, name, inputNode.getName());
return;
} else if (inputNode == null) {
// Input not in the graph at all: treat it as a placeholder.
sameDiff.addAsPlaceHolder(input);
return;
}
// Array available: coerce it into the declared field type.
val field = fields.get(name);
val type = field.getType();
if (type.equals(int[].class)) {
on.setValueFor(field, arr.data().asInt());
} else if (type.equals(int.class) || type.equals(long.class) || type.equals(Long.class) || type.equals(Integer.class)) {
// A shape position means "take that dimension's size", otherwise take
// the first element of the array as a scalar.
if (mapping.getShapePosition() != null) {
on.setValueFor(field, arr.size(mapping.getShapePosition()));
} else
on.setValueFor(field, arr.getInt(0));
} else if (type.equals(float.class) || type.equals(double.class) || type.equals(Float.class) || type.equals(Double.class)) {
on.setValueFor(field, arr.getDouble(0));
}
} else {
// Case 2: the property value comes from a node attribute.
val tfMappingAttrName = mapping.getTfAttrName();
if (tfMappingAttrName == null) {
return;
}
if (!node.containsAttr(tfMappingAttrName)) {
return;
}
val attr = node.getAttrOrThrow(tfMappingAttrName);
val type = attr.getType();
if (fields == null) {
throw new ND4JIllegalStateException("No fields found for op " + mapping);
}
if (mapping.getPropertyNames() == null) {
throw new ND4JIllegalStateException("no property found for " + name + " and op " + on.opName());
}
val field = fields.get(mapping.getPropertyNames()[0]);
// Extract the attribute value according to its TF dtype. All integer
// widths share getI(); DT_FLOAT and DT_DOUBLE both read getF().
Object valueToSet = null;
switch(type) {
case DT_BOOL:
valueToSet = attr.getB();
break;
case DT_INT8:
valueToSet = attr.getI();
break;
case DT_INT16:
valueToSet = attr.getI();
break;
case DT_INT32:
valueToSet = attr.getI();
break;
case DT_FLOAT:
valueToSet = attr.getF();
break;
case DT_DOUBLE:
// NOTE(review): getF() is the float accessor; presumably doubles are
// stored there too in this proto version — confirm precision is kept.
valueToSet = attr.getF();
break;
case DT_STRING:
valueToSet = attr.getS();
break;
case DT_INT64:
valueToSet = attr.getI();
break;
}
// Unknown field or unsupported dtype: silently skip (best-effort mapping).
if (field != null && valueToSet != null)
on.setValueFor(field, valueToSet);
}
}
Use of org.nd4j.linalg.exception.ND4JIllegalStateException in the nd4j project by deeplearning4j: the mapNodeType method of the TFGraphMapper class.
/**
 * Imports a single TensorFlow node into the SameDiff graph during graph
 * mapping: skips nodes already handled, and otherwise instantiates the
 * matching nd4j op, wires its input variables (creating placeholders for
 * not-yet-seen inputs), and registers the instance with the graph.
 *
 * @param tfNode      the TF node to import
 * @param importState carries the target SameDiff instance and source graph
 * @throws ND4JIllegalStateException if no nd4j op is registered for the
 *         node's TF op name
 * @throws RuntimeException wrapping any failure during op instantiation/init
 */
@Override
public void mapNodeType(NodeDef tfNode, ImportState<GraphDef, NodeDef> importState) {
// Skip filtered, already-imported, and variable nodes up front.
if (shouldSkip(tfNode) || alreadySeen(tfNode) || isVariableNode(tfNode)) {
return;
}
val diff = importState.getSameDiff();
// NOTE(review): this branch is unreachable — the guard above already
// returns for isVariableNode(tfNode). Dead code; candidate for removal
// (or the guard is wrong and variable nodes were meant to fall through).
if (isVariableNode(tfNode)) {
List<Integer> dimensions = new ArrayList<>();
Map<String, AttrValue> attributes = getAttrMap(tfNode);
if (attributes.containsKey(VALUE_ATTR_KEY)) {
diff.var(getName(tfNode), getArrayFrom(tfNode, importState.getGraph()));
} else if (attributes.containsKey(SHAPE_KEY)) {
AttrValue shape = attributes.get(SHAPE_KEY);
int[] shapeArr = getShapeFromAttr(shape);
int dims = shapeArr.length;
if (dims > 0) {
// even vector is 2d in nd4j
if (dims == 1)
dimensions.add(1);
for (int e = 0; e < dims; e++) {
// TODO: eventually we want long shapes :(
dimensions.add(getShapeFromAttr(shape)[e]);
}
}
}
} else if (isPlaceHolder(tfNode)) {
// Placeholder nodes only need to be registered as such.
val vertexId = diff.getVariable(getName(tfNode));
diff.addAsPlaceHolder(vertexId.getVarName());
} else {
val opName = tfNode.getOp();
val nodeName = tfNode.getName();
// FIXME: early draft
// conditional import
/*
if (nodeName.startsWith("cond") && nodeName.contains("/")) {
val str = nodeName.replaceAll("/.*$","");
importCondition(str, tfNode, importState);
seenNodes.add(nodeName);
return;
} else if (nodeName.startsWith("while")) {
// while loop import
return;
}
*/
// Resolve the nd4j op registered for this TF op name.
val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(opName);
if (differentialFunction == null) {
throw new ND4JIllegalStateException("No tensorflow op found for " + opName + " possibly missing operation class?");
}
try {
// The holder's instance is a prototype; create a fresh instance per node.
val newInstance = differentialFunction.getClass().newInstance();
val args = new SDVariable[tfNode.getInputCount()];
newInstance.setOwnName(tfNode.getName());
for (int i = 0; i < tfNode.getInputCount(); i++) {
val name = getNodeName(tfNode.getInput(i));
args[i] = diff.getVariable(name);
// Input variable not seen yet: create it shape-less and mark it as a
// placeholder to be resolved later.
if (args[i] == null) {
args[i] = diff.var(name, null, new ZeroInitScheme('f'));
diff.addAsPlaceHolder(args[i].getVarName());
}
/**
 * Note here that we are associating
 * the output/result variable
 * with its inputs and notifying
 * the variable that it has a place holder argument
 * it should resolve before trying to execute
 * anything.
 */
if (diff.isPlaceHolder(args[i].getVarName())) {
diff.putPlaceHolderForVariable(args[i].getVarName(), name);
}
}
diff.addArgsFor(args, newInstance);
newInstance.setSameDiff(importState.getSameDiff());
newInstance.initFromTensorFlow(tfNode, diff, getAttrMap(tfNode), importState.getGraph());
mapProperties(newInstance, tfNode, importState.getGraph(), importState.getSameDiff(), newInstance.mappingsForFunction());
importState.getSameDiff().putFunctionForId(newInstance.getOwnName(), newInstance);
// ensure we can track node name to function instance later.
diff.setBaseNameForFunctionInstanceId(tfNode.getName(), newInstance);
diff.addVarNameForImport(tfNode.getName());
} catch (Exception e) {
log.error("Failed with [{}]", opName);
throw new RuntimeException(e);
}
}
}
Use of org.nd4j.linalg.exception.ND4JIllegalStateException in the nd4j project by deeplearning4j: the storeAndAllocateNewArray method of the SDVariable class.
/**
 * Allocate and return a new array based on the vertex id and the variable's
 * weight initialization scheme, storing it in the owning SameDiff instance.
 * If an array of the declared shape is already present, it is returned as-is
 * and nothing is allocated.
 *
 * @return the existing array when its shape matches the declared shape,
 *         otherwise the freshly allocated array
 * @throws ND4JIllegalStateException if the variable name is null, or if no
 *         shape has been declared for this variable
 */
public INDArray storeAndAllocateNewArray() {
    // Validate preconditions before using them: the original checked varName
    // and shape only after calling getShapeForVarName/getArr with them.
    if (varName == null)
        throw new ND4JIllegalStateException("Unable to store array for null variable name!");
    val shape = sameDiff.getShapeForVarName(getVarName());
    if (shape == null) {
        throw new ND4JIllegalStateException("Unable to allocate new array. No shape found for variable " + varName);
    }
    // Reuse the current array when it already matches the declared shape.
    if (getArr() != null && Arrays.equals(getArr().shape(), shape))
        return getArr();
    val arr = getWeightInitScheme().create(shape);
    sameDiff.putArrayForVarName(getVarName(), arr);
    return arr;
}
Aggregations