Usage of org.nd4j.imports.NoOpNameFoundException in project nd4j by deeplearning4j.
Class BaseCompatOp, method mappingsForFunction:
/**
 * Builds the attribute-to-property mappings used when importing this op.
 *
 * <p>A single mapping table is used for every framework. It maps the {@code frame_name}
 * attribute onto the {@code frameName} property. The table is registered under the op's ONNX
 * name and under its TensorFlow name. A framework whose name lookup throws
 * {@code NoOpNameFoundException} is left out of the result.
 *
 * @return framework op name mapped to (property name mapped to its {@code PropertyMapping})
 */
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
    Map<String, PropertyMapping> propertyMappings = new HashMap<>();
    propertyMappings.put("frameName", PropertyMapping.builder()
            .tfAttrName("frame_name")
            // not sure if it exists in onnx
            .onnxAttrName("frame_name")
            .propertyNames(new String[] { "frameName" })
            .build());

    Map<String, Map<String, PropertyMapping>> byFramework = new HashMap<>();
    try {
        byFramework.put(onnxName(), propertyMappings);
    } catch (NoOpNameFoundException ignored) {
        // ONNX import does not matter for this set of ops; omit the entry.
    }
    try {
        byFramework.put(tensorflowName(), propertyMappings);
    } catch (NoOpNameFoundException ignored) {
        // No TensorFlow name either: simply omit the entry as well.
    }
    return byFramework;
}
Usage of org.nd4j.imports.NoOpNameFoundException in project nd4j by deeplearning4j.
Class OnnxGraphMapper, method mapNodeType:
/**
 * Imports a single ONNX node into the SameDiff instance held by {@code importState}.
 *
 * <p>The node's op type is resolved to a registered {@code DifferentialFunction}. A fresh
 * instance of that function's class is then created and attached to the SameDiff instance. It
 * is initialised via {@code initFromOnnx} and registered under its own name. Finally the ONNX
 * node name is recorded against the new instance.
 *
 * @param tfNode      the node to import (despite the parameter name, this is an ONNX node)
 * @param importState import state providing the target SameDiff instance and the source graph
 * @throws NoOpNameFoundException if no op is registered for the node's ONNX op type
 */
@Override
public void mapNodeType(OnnxProto3.NodeProto tfNode, ImportState<OnnxProto3.GraphProto, onnx.OnnxProto3.TypeProto.Tensor> importState) {
    val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(tfNode.getOpType());
    if (differentialFunction == null) {
        throw new NoOpNameFoundException("No op name found " + tfNode.getOpType());
    }
    val diff = importState.getSameDiff();
    try {
        val newInstance = differentialFunction.getClass().newInstance();
        newInstance.setSameDiff(diff);
        newInstance.initFromOnnx(tfNode, diff, getAttrMap(tfNode), importState.getGraph());
        diff.putFunctionForId(newInstance.getOwnName(), newInstance);
        // ensure we can track node name to function instance later.
        diff.setBaseNameForFunctionInstanceId(tfNode.getName(), newInstance);
        diff.addVarNameForImport(tfNode.getName());
    } catch (Exception e) {
        // NOTE(review): failures are only printed. A node that fails to instantiate or initialise
        // is skipped, and the imported graph may be incomplete. Consider rethrowing with the node
        // name and op type as context.
        e.printStackTrace();
    }
}
Usage of org.nd4j.imports.NoOpNameFoundException in project nd4j by deeplearning4j.
Class BaseGraphMapper, method opTypeForNode:
/**
 * Returns the op type of the function mapped to {@code nodeDef}.
 *
 * <p>The node's op type name is obtained via {@code getOpType} and looked up with
 * {@code getMappedOp}.
 *
 * @param nodeDef the framework node to classify
 * @return the {@code Op.Type} reported by the mapped function
 * @throws NoOpNameFoundException if no function is mapped to the node's op type name
 */
@Override
public Op.Type opTypeForNode(NODE_TYPE nodeDef) {
    // Evaluate once: the same name is used both for the lookup and for the error message.
    String opType = getOpType(nodeDef);
    DifferentialFunction mappedOp = getMappedOp(opType);
    if (mappedOp == null) {
        throw new NoOpNameFoundException("No op found with name " + opType);
    }
    return mappedOp.opType();
}
Usage of org.nd4j.imports.NoOpNameFoundException in project nd4j by deeplearning4j.
Class Dilation2D, method mappingsForFunction:
/**
 * Builds the attribute-to-property mappings used when importing this op.
 *
 * <p>The TensorFlow {@code padding} attribute maps to {@code isSameMode}. The {@code rates}
 * attribute feeds {@code r0}..{@code r3}, and {@code strides} feeds {@code s0}..{@code s3}.
 * The same table is registered under the op's ONNX name (skipped if no ONNX name exists) and
 * under its TensorFlow name.
 *
 * @return framework op name mapped to (property name mapped to its {@code PropertyMapping})
 * @throws RuntimeException wrapping {@code NoOpNameFoundException} if no TensorFlow name exists
 */
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
    String[] rateProperties = { "r0", "r1", "r2", "r3" };
    String[] strideProperties = { "s0", "s1", "s2", "s3" };

    PropertyMapping paddingMapping = PropertyMapping.builder()
            .tfAttrName("padding")
            .propertyNames(new String[] { "isSameMode" })
            .build();
    PropertyMapping rateMapping = PropertyMapping.builder()
            .tfAttrName("rates")
            .propertyNames(rateProperties)
            .build();
    PropertyMapping strideMapping = PropertyMapping.builder()
            .tfAttrName("strides")
            .propertyNames(strideProperties)
            .build();

    Map<String, PropertyMapping> attributeMap = new HashMap<>();
    attributeMap.put("isSameMode", paddingMapping);
    // Every component property shares the mapping of the attribute it is read from.
    for (String property : rateProperties) {
        attributeMap.put(property, rateMapping);
    }
    for (String property : strideProperties) {
        attributeMap.put(property, strideMapping);
    }

    Map<String, Map<String, PropertyMapping>> mappingsByFramework = new HashMap<>();
    try {
        mappingsByFramework.put(onnxName(), attributeMap);
    } catch (NoOpNameFoundException ignored) {
        // ONNX import does not matter for this set of ops; omit the entry.
    }
    try {
        mappingsByFramework.put(tensorflowName(), attributeMap);
    } catch (NoOpNameFoundException e) {
        // Unlike ONNX, a missing TensorFlow name is treated as an error.
        throw new RuntimeException(e);
    }
    return mappingsByFramework;
}
Usage of org.nd4j.imports.NoOpNameFoundException in project nd4j by deeplearning4j.
Class OpsMappingTests, method addOperation:
/**
 * Appends the operation name(s) exposed by {@code clazz} to {@code list}.
 *
 * <p>Abstract classes and interfaces are skipped.
 * <ul>
 *   <li>For a {@code DynamicCustomOp}, two entries with op number {@code -1} are added: the
 *       lower-cased op name, then the lower-cased TensorFlow name. If {@code tensorflowName()}
 *       throws, the op-name entry has already been added and is kept.</li>
 *   <li>Any other op contributes one entry built from its op number and op name.</li>
 * </ul>
 *
 * <p>Classes that cannot be instantiated are skipped silently, as are ops that report an
 * unsupported operation or a missing op name. Any other failure is logged and rethrown.
 *
 * @param clazz op class to inspect
 * @param list  destination list the operations are appended to
 * @throws RuntimeException wrapping any unexpected failure while inspecting {@code clazz}
 */
protected void addOperation(Class<? extends DifferentialFunction> clazz, List<Operation> list) {
    if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface()) {
        return;
    }
    try {
        DifferentialFunction node = clazz.newInstance();
        if (node instanceof DynamicCustomOp) {
            // Locale.ROOT keeps lower-casing independent of the JVM default locale
            // (e.g. Turkish dotted/dotless i), so op names compare reliably.
            list.add(new Operation(-1L, node.opName().toLowerCase(java.util.Locale.ROOT)));
            list.add(new Operation(-1L, node.tensorflowName().toLowerCase(java.util.Locale.ROOT)));
        } else {
            list.add(new Operation(Long.valueOf(node.opNum()), node.opName()));
        }
    } catch (UnsupportedOperationException | NoOpNameFoundException | InstantiationException ignored) {
        // Expected for ops that cannot be instantiated reflectively or that expose no name
        // for this mapping; such ops are deliberately left out of the list.
    } catch (Exception e) {
        log.info("Failed on [{}]", clazz.getSimpleName());
        throw new RuntimeException(e);
    }
}
Aggregations