
Example 6 with AttributeAdapter

Use of org.nd4j.imports.descriptors.properties.AttributeAdapter in project nd4j by deeplearning4j.

From the class Dilation2D, method attributeAdaptersForFunction.

@Override
public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
    Map<String, Map<String, AttributeAdapter>> ret = new HashMap<>();
    Map<String, AttributeAdapter> tfMappings = new LinkedHashMap<>();
    val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
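    // each adapter copies one element, by index, of a 4-element int array attribute into the named field
    tfMappings.put("r0", new IntArrayIntIndexAdpater(0));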
    tfMappings.put("r1", new IntArrayIntIndexAdpater(1));
    tfMappings.put("r2", new IntArrayIntIndexAdpater(2));
    tfMappings.put("r3", new IntArrayIntIndexAdpater(3));
    tfMappings.put("s0", new IntArrayIntIndexAdpater(0));
    tfMappings.put("s1", new IntArrayIntIndexAdpater(1));
    tfMappings.put("s2", new IntArrayIntIndexAdpater(2));
    tfMappings.put("s3", new IntArrayIntIndexAdpater(3));
    tfMappings.put("isSameMode", new StringEqualsAdapter("SAME"));
    // ONNX may not define a Dilation2D op; only isSameMode is mapped for it here
    Map<String, AttributeAdapter> onnxMappings = new HashMap<>();
    onnxMappings.put("isSameMode", new StringEqualsAdapter("SAME"));
    ret.put(tensorflowName(), tfMappings);
    ret.put(onnxName(), onnxMappings);
    return ret;
}
Also used : lombok.val(lombok.val) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) AttributeAdapter(org.nd4j.imports.descriptors.properties.AttributeAdapter) Map(java.util.Map)
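The Dilation2D mappings rely on two adapter behaviours: IntArrayIntIndexAdpater copies one element of an int array attribute into a field, and StringEqualsAdapter turns a string attribute into a boolean. The sketch below imitates both with plain java.util.function.Function lambdas. SimpleAdapter, intIndex and stringEquals are hypothetical names, not the nd4j API; the real adapters write into the op's field through mapAttributeFor instead of returning a value.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;

public class IndexAdapterDemo {
    // Hypothetical stand-in for an attribute adapter: converts a raw attribute
    // value into the value that should be stored on a field.
    interface SimpleAdapter extends Function<Object, Object> {}

    // Same idea as IntArrayIntIndexAdpater: take one element of an int[] attribute.
    static SimpleAdapter intIndex(int index) {
        return raw -> ((int[]) raw)[index];
    }

    // Same idea as StringEqualsAdapter: true when the attribute equals a fixed string.
    static SimpleAdapter stringEquals(String expected) {
        return raw -> expected.equals(raw);
    }

    public static void main(String[] args) {
        Map<String, SimpleAdapter> adapters = new LinkedHashMap<>();
        adapters.put("s1", intIndex(1));
        adapters.put("s2", intIndex(2));
        adapters.put("isSameMode", stringEquals("SAME"));

        int[] strides = {1, 2, 3, 1}; // a 4-element stride attribute
        System.out.println(adapters.get("s1").apply(strides));         // 2
        System.out.println(adapters.get("s2").apply(strides));         // 3
        System.out.println(adapters.get("isSameMode").apply("VALID")); // false
    }
}
```

Keying the adapters by field name in a LinkedHashMap mirrors the snippet: the importer looks up the adapter for a field, and fields without an entry fall back to direct assignment.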

Example 7 with AttributeAdapter

Use of org.nd4j.imports.descriptors.properties.AttributeAdapter in project nd4j by deeplearning4j.

From the class Cast, method attributeAdaptersForFunction.

@Override
public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
    Map<String, Map<String, AttributeAdapter>> ret = new LinkedHashMap<>();
    Map<String, AttributeAdapter> tfAdapters = new LinkedHashMap<>();
    val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
    tfAdapters.put("typeDst", new DataTypeAdapter());
    ret.put(tensorflowName(), tfAdapters);
    return ret;
}
Also used : lombok.val(lombok.val) AttributeAdapter(org.nd4j.imports.descriptors.properties.AttributeAdapter) DataTypeAdapter(org.nd4j.imports.descriptors.properties.adapters.DataTypeAdapter)
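Cast registers a single adapter, DataTypeAdapter, for its typeDst field. Its exact conversion table is not shown here. As an assumption, the sketch below shows the general shape of such an adapter: it translates TensorFlow DataType enum names (DT_FLOAT, DT_INT32 and so on, as they appear in a GraphDef) into a hypothetical TargetType enum that stands in for ND4J's own data-type enum.

```java
import java.util.Locale;

public class DataTypeMappingDemo {
    // Hypothetical target enum standing in for ND4J's data-type enum.
    enum TargetType { FLOAT, DOUBLE, INT, LONG, BOOL, UNKNOWN }

    // Maps a TensorFlow DataType enum name to the target enum.
    static TargetType fromTfName(String tfName) {
        switch (tfName.toUpperCase(Locale.ROOT)) {
            case "DT_FLOAT":  return TargetType.FLOAT;
            case "DT_DOUBLE": return TargetType.DOUBLE;
            case "DT_INT32":  return TargetType.INT;
            case "DT_INT64":  return TargetType.LONG;
            case "DT_BOOL":   return TargetType.BOOL;
            default:          return TargetType.UNKNOWN; // not covered by this sketch
        }
    }

    public static void main(String[] args) {
        System.out.println(fromTfName("DT_INT64")); // LONG
        System.out.println(fromTfName("DT_HALF"));  // UNKNOWN
    }
}
```

Returning UNKNOWN instead of throwing keeps the sketch total; an importer would more likely reject an unsupported type outright.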

Example 8 with AttributeAdapter

Use of org.nd4j.imports.descriptors.properties.AttributeAdapter in project nd4j by deeplearning4j.

From the class FullConv3D, method attributeAdaptersForFunction.

@Override
public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
    Map<String, Map<String, AttributeAdapter>> ret = new LinkedHashMap<>();
    Map<String, AttributeAdapter> tfAdapters = new LinkedHashMap<>();
    tfAdapters.put("dT", new IntArrayIntIndexAdpater(1));
    tfAdapters.put("dW", new IntArrayIntIndexAdpater(2));
    tfAdapters.put("dH", new IntArrayIntIndexAdpater(3));
    tfAdapters.put("pT", new IntArrayIntIndexAdpater(1));
    tfAdapters.put("pW", new IntArrayIntIndexAdpater(2));
    tfAdapters.put("pH", new IntArrayIntIndexAdpater(3));
    ret.put(tensorflowName(), tfAdapters);
    return ret;
}
Also used : AttributeAdapter(org.nd4j.imports.descriptors.properties.AttributeAdapter) IntArrayIntIndexAdpater(org.nd4j.imports.descriptors.properties.adapters.IntArrayIntIndexAdpater)
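FullConv3D reads six fields out of int array attributes by position. For reference, TensorFlow's Conv3D takes a length-5 strides attribute in NDHWC order and requires strides[0] == strides[4] == 1, so positions 1, 2 and 3 hold the depth, height and width strides. The hypothetical helper spatialStrides below validates that layout and extracts the three spatial values; which index FullConv3D assigns to dW and dH is defined by the snippet above, not by this sketch.

```java
import java.util.Arrays;

public class Conv3dStridesDemo {
    // Extracts {depth, height, width} from a length-5 NDHWC strides attribute,
    // rejecting arrays whose batch and channel strides are not 1.
    static int[] spatialStrides(int[] strides) {
        if (strides.length != 5 || strides[0] != 1 || strides[4] != 1) {
            throw new IllegalArgumentException("expected [1, d, h, w, 1], got " + Arrays.toString(strides));
        }
        return new int[] {strides[1], strides[2], strides[3]};
    }

    public static void main(String[] args) {
        int[] dhw = spatialStrides(new int[] {1, 2, 3, 4, 1});
        System.out.println(Arrays.toString(dhw)); // [2, 3, 4]
    }
}
```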

Example 9 with AttributeAdapter

Use of org.nd4j.imports.descriptors.properties.AttributeAdapter in project nd4j by deeplearning4j.

From the class DepthwiseConv2D, method attributeAdaptersForFunction.

@Override
public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
    Map<String, Map<String, AttributeAdapter>> ret = new HashMap<>();
    Map<String, AttributeAdapter> tfMappings = new LinkedHashMap<>();
    val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
    tfMappings.put("kh", new ConditionalFieldValueNDArrayShapeAdapter("NCHW", 0, 0, fields.get("dataFormat")));
    tfMappings.put("kw", new ConditionalFieldValueNDArrayShapeAdapter("NCHW", 1, 1, fields.get("dataFormat")));
    tfMappings.put("sy", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 2, 1, fields.get("dataFormat")));
    tfMappings.put("sx", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 3, 2, fields.get("dataFormat")));
    tfMappings.put("isSameMode", new StringEqualsAdapter("SAME"));
    tfMappings.put("isNHWC", new StringEqualsAdapter("NHWC"));
    Map<String, AttributeAdapter> onnxMappings = new HashMap<>();
    onnxMappings.put("kh", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxMappings.put("kw", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxMappings.put("dh", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxMappings.put("dw", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxMappings.put("sy", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxMappings.put("sx", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxMappings.put("isSameMode", new StringEqualsAdapter("SAME"));
    onnxMappings.put("isNHWC", new StringEqualsAdapter("NHWC"));
    try {
        ret.put(tensorflowName(), tfMappings);
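    } catch (NoOpNameFoundException e) {
        // no TensorFlow name for this op: leave the TensorFlow mappings out
    }
    try {
        ret.put(onnxName(), onnxMappings);
    } catch (NoOpNameFoundException e) {
        // no ONNX name for this op: leave the ONNX mappings out
    }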
    return ret;
}
Also used : lombok.val(lombok.val) ConditionalFieldValueNDArrayShapeAdapter(org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueNDArrayShapeAdapter) StringEqualsAdapter(org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter) SizeThresholdIntArrayIntIndexAdpater(org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater) AttributeAdapter(org.nd4j.imports.descriptors.properties.AttributeAdapter) NoOpNameFoundException(org.nd4j.imports.NoOpNameFoundException) ConditionalFieldValueIntIndexArrayAdapter(org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter)
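DepthwiseConv2D picks stride indices according to the op's dataFormat field: for sy the snippet passes ("NCHW", 2, 1) and for sx ("NCHW", 3, 2). Those pairs match where the strides sit in an NCHW array ([1, 1, sy, sx]) versus an NHWC array ([1, sy, sx, 1]). Assuming ConditionalFieldValueIntIndexArrayAdapter uses the first index when the field equals the target value and the second otherwise, the selection reduces to the hypothetical pick helper below.

```java
public class ConditionalIndexDemo {
    // Assumed semantics: use trueIndex when the format field equals the target
    // value, otherwise falseIndex, then read that element of the attribute array.
    static int pick(int[] values, String dataFormat, String target, int trueIndex, int falseIndex) {
        int index = target.equals(dataFormat) ? trueIndex : falseIndex;
        return values[index];
    }

    public static void main(String[] args) {
        int[] nchwStrides = {1, 1, 2, 3}; // [N, C, H, W]: sy = 2, sx = 3
        int[] nhwcStrides = {1, 2, 3, 1}; // [N, H, W, C]: sy = 2, sx = 3
        // sy: ("NCHW", 2, 1); sx: ("NCHW", 3, 2), the same triples as the snippet
        System.out.println(pick(nchwStrides, "NCHW", "NCHW", 2, 1)); // 2
        System.out.println(pick(nhwcStrides, "NHWC", "NCHW", 2, 1)); // 2
        System.out.println(pick(nchwStrides, "NCHW", "NCHW", 3, 2)); // 3
        System.out.println(pick(nhwcStrides, "NHWC", "NCHW", 3, 2)); // 3
    }
}
```

Both layouts yield the same sy and sx, which is the point of making the index conditional on the format field.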

Example 10 with AttributeAdapter

Use of org.nd4j.imports.descriptors.properties.AttributeAdapter in project nd4j by deeplearning4j.

From the class OnnxGraphMapper, method initFunctionFromProperties.

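/**
 * Initializes a function's attributes from an ONNX node.
 * @param mappedTfName the ONNX op name to use (some ops have multiple names)
 * @param on the function to map
 * @param attributesForNode the attributes for the node, keyed by attribute name
 * @param node the ONNX node being imported
 * @param graph the ONNX graph that contains the node
 */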
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.NodeProto node, OnnxProto3.GraphProto graph) {
    val properties = on.mappingsForFunction();
    val tfProperties = properties.get(mappedTfName);
    if (tfProperties == null) {
        return; // no property mappings registered under this op name
    }
    val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
    val attributeAdapters = on.attributeAdaptersForFunction();
    for (val entry : tfProperties.entrySet()) {
        val tfAttrName = entry.getValue().getTfAttrName();
        val currentField = fields.get(entry.getKey());
        AttributeAdapter adapter = null;
        if (tfAttrName != null) {
            if (currentField == null) {
                continue;
            }
            if (attributeAdapters != null && !attributeAdapters.isEmpty()) {
                // adapters are keyed by framework op name, so look up the ONNX entry
                val mappers = attributeAdapters.get(on.onnxName());
                if (mappers != null) {
                    adapter = mappers.get(entry.getKey());
                }
            }
            if (attributesForNode.containsKey(tfAttrName)) {
                val attr = attributesForNode.get(tfAttrName);
                switch(attr.getType()) {
                    case STRING:
                        val setString = attr.getS().toStringUtf8();
                        if (adapter != null) {
                            adapter.mapAttributeFor(setString, currentField, on);
                        } else
                            on.setValueFor(currentField, setString);
                        break;
                    case INT:
                        val setInt = (int) attr.getI();
                        if (adapter != null) {
                            adapter.mapAttributeFor(setInt, currentField, on);
                        } else
                            on.setValueFor(currentField, setInt);
                        break;
                    case INTS:
                        val setList = attr.getIntsList();
                        if (!setList.isEmpty()) {
                            val intList = Ints.toArray(setList);
                            if (adapter != null) {
                                adapter.mapAttributeFor(intList, currentField, on);
                            } else
                                on.setValueFor(currentField, intList);
                        }
                        break;
                    case FLOATS:
                        val floatsList = attr.getFloatsList();
                        if (!floatsList.isEmpty()) {
                            val floats = Floats.toArray(floatsList);
                            if (adapter != null) {
                                adapter.mapAttributeFor(floats, currentField, on);
                            } else
                                on.setValueFor(currentField, floats);
                        }
                        break;
                    case TENSOR:
                        val tensorToGet = mapTensorProto(attr.getT());
                        if (adapter != null) {
                            adapter.mapAttributeFor(tensorToGet, currentField, on);
                        } else
                            on.setValueFor(currentField, tensorToGet);
                        break;
                }
            }
        }
    }
}
Also used : lombok.val(lombok.val) AttributeAdapter(org.nd4j.imports.descriptors.properties.AttributeAdapter)
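initFunctionFromProperties applies one rule to every attribute type: if an adapter is registered for the target field, the raw value goes through adapter.mapAttributeFor; otherwise it is assigned with setValueFor. The sketch below shows that dispatch with plain maps. The apply helper, the Function-based adapters and the attribute names are illustrative assumptions, and unlike the ONNX mapper the sketch does not branch on the protobuf attribute type.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

public class AttributeDispatchDemo {
    // For each (field -> attribute name) mapping, read the raw attribute and
    // store it on the target: through the field's adapter when one is
    // registered, otherwise unchanged (the setValueFor path).
    static Map<String, Object> apply(Map<String, String> fieldToAttr,
                                     Map<String, Object> rawAttrs,
                                     Map<String, Function<Object, Object>> adapters) {
        Map<String, Object> target = new HashMap<>();
        for (Map.Entry<String, String> e : fieldToAttr.entrySet()) {
            Object raw = rawAttrs.get(e.getValue());
            if (raw == null) {
                continue; // attribute not present on this node: leave the field unset
            }
            Function<Object, Object> adapter = adapters == null ? null : adapters.get(e.getKey());
            target.put(e.getKey(), adapter != null ? adapter.apply(raw) : raw);
        }
        return target;
    }

    public static void main(String[] args) {
        Map<String, String> fieldToAttr = new HashMap<>();
        fieldToAttr.put("isSameMode", "padding");
        fieldToAttr.put("kh", "kernel_shape");
        fieldToAttr.put("group", "group");

        Map<String, Object> rawAttrs = new HashMap<>();
        rawAttrs.put("padding", "SAME");
        rawAttrs.put("kernel_shape", new int[] {3, 5});
        rawAttrs.put("group", 1);

        Map<String, Function<Object, Object>> adapters = new HashMap<>();
        adapters.put("isSameMode", raw -> "SAME".equals(raw)); // like StringEqualsAdapter
        adapters.put("kh", raw -> ((int[]) raw)[0]);          // like an int-index adapter

        Map<String, Object> result = apply(fieldToAttr, rawAttrs, adapters);
        System.out.println(result.get("isSameMode")); // true
        System.out.println(result.get("kh"));         // 3
        System.out.println(result.get("group"));      // 1 (no adapter, stored as-is)
    }
}
```

Treating a missing adapter map the same as a missing entry matches the null checks in the mapper, so ops that register no adapters still import their plain attributes.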

Aggregations

AttributeAdapter (org.nd4j.imports.descriptors.properties.AttributeAdapter) 11
lombok.val (lombok.val) 8
ConditionalFieldValueNDArrayShapeAdapter (org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueNDArrayShapeAdapter) 4
StringEqualsAdapter (org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter) 4
ConditionalFieldValueIntIndexArrayAdapter (org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter) 3
SizeThresholdIntArrayIntIndexAdpater (org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater) 3
BooleanAdapter (org.nd4j.imports.descriptors.properties.adapters.BooleanAdapter) 2
IntArrayIntIndexAdpater (org.nd4j.imports.descriptors.properties.adapters.IntArrayIntIndexAdpater) 2
HashMap (java.util.HashMap) 1
LinkedHashMap (java.util.LinkedHashMap) 1
Map (java.util.Map) 1
NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException) 1
DataTypeAdapter (org.nd4j.imports.descriptors.properties.adapters.DataTypeAdapter) 1
INDArray (org.nd4j.linalg.api.ndarray.INDArray) 1