Usage of org.nd4j.imports.descriptors.properties.AttributeAdapter in project nd4j (deeplearning4j).
Class Dilation2D, method attributeAdaptersForFunction:
/**
 * Attribute adapters used when importing this op.
 *
 * <p>TensorFlow: {@code r0..r3} and {@code s0..s3} each take the element at index 0..3 of their
 * mapped int-array attribute; {@code isSameMode} is true when the mapped string equals "SAME".
 * ONNX: only {@code isSameMode} is adapted.
 *
 * @return adapters keyed by framework op name, then by field name
 */
@Override
public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
    Map<String, Map<String, AttributeAdapter>> ret = new LinkedHashMap<>();

    Map<String, AttributeAdapter> tfMappings = new LinkedHashMap<>();
    // r0..r3 first, then s0..s3: each reads element i of its mapped int array.
    for (int i = 0; i < 4; i++) {
        tfMappings.put("r" + i, new IntArrayIntIndexAdpater(i));
    }
    for (int i = 0; i < 4; i++) {
        tfMappings.put("s" + i, new IntArrayIntIndexAdpater(i));
    }
    tfMappings.put("isSameMode", new StringEqualsAdapter("SAME"));

    // NOTE(review): it is unclear whether ONNX defines this op. If onnxName() throws for this
    // class, this whole method throws; consider guarding it like ops that catch
    // NoOpNameFoundException - confirm.
    Map<String, AttributeAdapter> onnxMappings = new LinkedHashMap<>();
    onnxMappings.put("isSameMode", new StringEqualsAdapter("SAME"));

    ret.put(tensorflowName(), tfMappings);
    ret.put(onnxName(), onnxMappings);
    return ret;
}
Usage of org.nd4j.imports.descriptors.properties.AttributeAdapter in project nd4j (deeplearning4j).
Class Cast, method attributeAdaptersForFunction:
/**
 * Attribute adapters used when importing this op from TensorFlow.
 *
 * <p>The {@code typeDst} field is populated through a {@link DataTypeAdapter}.
 *
 * @return a single entry keyed by {@link #tensorflowName()}, mapping field names to adapters
 */
@Override
public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
    Map<String, AttributeAdapter> tfAdapters = new LinkedHashMap<>();
    tfAdapters.put("typeDst", new DataTypeAdapter());

    Map<String, Map<String, AttributeAdapter>> ret = new LinkedHashMap<>();
    ret.put(tensorflowName(), tfAdapters);
    return ret;
}
Usage of org.nd4j.imports.descriptors.properties.AttributeAdapter in project nd4j (deeplearning4j).
Class FullConv3D, method attributeAdaptersForFunction:
/**
 * Attribute adapters used when importing this op from TensorFlow.
 *
 * <p>Each adapter picks one element of a 5-element int-array attribute. The indices assume the
 * TensorFlow NDHWC layout {@code [1, depth, height, width, 1]} used by Conv3D's {@code strides}
 * attribute: depth ({@code T}) is index 1, height ({@code H}) index 2, width ({@code W}) index 3.
 * NOTE(review): confirm that the attributes mapped in mappingsForFunction() use this layout.
 *
 * <p>NOTE(review): TensorFlow's Conv3D {@code padding} attribute is a string (SAME/VALID), so the
 * {@code pT}/{@code pH}/{@code pW} adapters only take effect if they are mapped to an int-array
 * attribute - confirm.
 *
 * @return adapters keyed by {@link #tensorflowName()}, then by field name
 */
@Override
public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
    Map<String, AttributeAdapter> tfAdapters = new LinkedHashMap<>();
    // Height precedes width in NDHWC; H and W were previously read from swapped indices.
    tfAdapters.put("dT", new IntArrayIntIndexAdpater(1));
    tfAdapters.put("dH", new IntArrayIntIndexAdpater(2));
    tfAdapters.put("dW", new IntArrayIntIndexAdpater(3));
    tfAdapters.put("pT", new IntArrayIntIndexAdpater(1));
    tfAdapters.put("pH", new IntArrayIntIndexAdpater(2));
    tfAdapters.put("pW", new IntArrayIntIndexAdpater(3));

    Map<String, Map<String, AttributeAdapter>> ret = new LinkedHashMap<>();
    ret.put(tensorflowName(), tfAdapters);
    return ret;
}
Usage of org.nd4j.imports.descriptors.properties.AttributeAdapter in project nd4j (deeplearning4j).
Class DepthwiseConv2D, method attributeAdaptersForFunction:
/**
 * Attribute adapters used when importing this op from TensorFlow and ONNX.
 *
 * <p>TensorFlow: {@code kh}/{@code kw} come from the filter shape and {@code sy}/{@code sx} from
 * the strides attribute, with the index chosen by the op's {@code dataFormat} field ("NCHW" vs.
 * otherwise). {@code isSameMode} and {@code isNHWC} are string comparisons.
 * ONNX: kernel size, dilation and stride use size-threshold index adapters.
 *
 * <p>A framework's entry is omitted when this op has no op name for that framework
 * ({@code NoOpNameFoundException}).
 *
 * @return adapters keyed by framework op name, then by field name
 */
@Override
public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
    Map<String, Map<String, AttributeAdapter>> ret = new LinkedHashMap<>();
    val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
    // Every conditional TF adapter below switches its index on this one field.
    val dataFormatField = fields.get("dataFormat");

    // Presumably (value to compare, index when field equals it, index otherwise, field) - confirm.
    Map<String, AttributeAdapter> tfMappings = new LinkedHashMap<>();
    tfMappings.put("kh", new ConditionalFieldValueNDArrayShapeAdapter("NCHW", 0, 0, dataFormatField));
    tfMappings.put("kw", new ConditionalFieldValueNDArrayShapeAdapter("NCHW", 1, 1, dataFormatField));
    tfMappings.put("sy", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 2, 1, dataFormatField));
    tfMappings.put("sx", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 3, 2, dataFormatField));
    tfMappings.put("isSameMode", new StringEqualsAdapter("SAME"));
    tfMappings.put("isNHWC", new StringEqualsAdapter("NHWC"));

    Map<String, AttributeAdapter> onnxMappings = new LinkedHashMap<>();
    onnxMappings.put("kh", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxMappings.put("kw", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxMappings.put("dh", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxMappings.put("dw", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxMappings.put("sy", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxMappings.put("sx", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxMappings.put("isSameMode", new StringEqualsAdapter("SAME"));
    onnxMappings.put("isNHWC", new StringEqualsAdapter("NHWC"));

    try {
        ret.put(tensorflowName(), tfMappings);
    } catch (NoOpNameFoundException ignored) {
        // No TensorFlow op name for this op: deliberately leave TensorFlow adapters out.
    }
    try {
        ret.put(onnxName(), onnxMappings);
    } catch (NoOpNameFoundException ignored) {
        // No ONNX op name for this op: deliberately leave ONNX adapters out.
    }
    return ret;
}
Usage of org.nd4j.imports.descriptors.properties.AttributeAdapter in project nd4j (deeplearning4j).
Class OnnxGraphMapper, method initFunctionFromProperties:
/**
 * Initialize a function's fields from the attributes of an ONNX node.
 *
 * <p>For every property mapping registered under {@code mappedTfName}, the attribute named by the
 * mapping is read from {@code attributesForNode} and written to the matching field of {@code on}.
 * If an {@link AttributeAdapter} is registered for that field under the same op name, the adapter
 * performs the write; otherwise the raw value is set directly. Only STRING, INT, INTS, FLOATS and
 * TENSOR attributes are handled. Empty INTS/FLOATS lists and all other attribute types are ignored.
 *
 * @param mappedTfName the ONNX op name whose mappings should be used (some ops have multiple names)
 * @param on the function whose fields are initialized
 * @param attributesForNode the node's attributes, keyed by attribute name
 * @param node the ONNX node (not read by this method)
 * @param graph the graph containing the node (not read by this method)
 */
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.NodeProto node, OnnxProto3.GraphProto graph) {
    val properties = on.mappingsForFunction();
    if (properties == null) {
        return;
    }
    val tfProperties = properties.get(mappedTfName);
    if (tfProperties == null) {
        // No property mappings for this op name: nothing to initialize.
        return;
    }
    val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);

    // Adapters are looked up under the same op name as the property mappings. tensorflowName()
    // must not be used here: it can throw for ops without a TensorFlow name, and ONNX adapters
    // are registered under the ONNX op name.
    val attributeAdapters = on.attributeAdaptersForFunction();
    Map<String, AttributeAdapter> mappers = null;
    if (attributeAdapters != null) {
        mappers = attributeAdapters.get(mappedTfName);
    }

    for (val entry : tfProperties.entrySet()) {
        val fieldName = entry.getKey();
        val attrName = entry.getValue().getTfAttrName();
        val currentField = fields.get(fieldName);
        // Skip mappings with no attribute name, no backing field, or no value on this node.
        if (attrName == null || currentField == null || !attributesForNode.containsKey(attrName)) {
            continue;
        }

        AttributeAdapter adapter = mappers == null ? null : mappers.get(fieldName);
        val attr = attributesForNode.get(attrName);
        switch (attr.getType()) {
            case STRING: {
                val setString = attr.getS().toStringUtf8();
                if (adapter != null) {
                    adapter.mapAttributeFor(setString, currentField, on);
                } else {
                    on.setValueFor(currentField, setString);
                }
                break;
            }
            case INT: {
                // NOTE(review): getI() is narrowed to int; out-of-range values will wrap.
                val setInt = (int) attr.getI();
                if (adapter != null) {
                    adapter.mapAttributeFor(setInt, currentField, on);
                } else {
                    on.setValueFor(currentField, setInt);
                }
                break;
            }
            case INTS: {
                val setList = attr.getIntsList();
                if (!setList.isEmpty()) {
                    val intList = Ints.toArray(setList);
                    if (adapter != null) {
                        adapter.mapAttributeFor(intList, currentField, on);
                    } else {
                        on.setValueFor(currentField, intList);
                    }
                }
                break;
            }
            case FLOATS: {
                val floatsList = attr.getFloatsList();
                if (!floatsList.isEmpty()) {
                    val floats = Floats.toArray(floatsList);
                    if (adapter != null) {
                        adapter.mapAttributeFor(floats, currentField, on);
                    } else {
                        on.setValueFor(currentField, floats);
                    }
                }
                break;
            }
            case TENSOR: {
                val tensorToGet = mapTensorProto(attr.getT());
                if (adapter != null) {
                    adapter.mapAttributeFor(tensorToGet, currentField, on);
                } else {
                    on.setValueFor(currentField, tensorToGet);
                }
                break;
            }
            default:
                // Other attribute types are not mapped onto fields.
                break;
        }
    }
}
Aggregations