Search in sources:

Example 6 with NoOpNameFoundException

Use of org.nd4j.imports.NoOpNameFoundException in project nd4j by deeplearning4j.

Found in class Conv2D, method mappingsForFunction.

/**
 * Returns the property mappings used when importing this op from ONNX or TensorFlow.
 *
 * <p>Each entry of the inner map associates a property name of this op (for example
 * {@code "sx"} or {@code "kh"}) with the {@link PropertyMapping} that says which ONNX attribute,
 * TF attribute, or TF input-shape dimension the property is read from. The same inner map is
 * registered under both the ONNX op name and the TensorFlow op name.
 *
 * @return map keyed by framework op name; a framework is omitted when looking up its op name
 *     throws {@link NoOpNameFoundException}
 */
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
    Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
    Map<String, PropertyMapping> map = new HashMap<>();

    // A single mapping feeds both stride properties.
    val strideMapping = PropertyMapping.builder()
            .tfAttrName("strides")
            .onnxAttrName("strides")
            .propertyNames(new String[] { "sx", "sy" })
            .build();
    // Kernel height/width come from dimensions 0 and 1 of the shape of TF input 1
    // (presumably the weights tensor - confirm); ONNX supplies them via "kernel_shape".
    val kernelMappingH = PropertyMapping.builder()
            .propertyNames(new String[] { "kh" })
            .tfInputPosition(1)
            .shapePosition(0)
            .onnxAttrName("kernel_shape")
            .build();
    val kernelMappingW = PropertyMapping.builder()
            .propertyNames(new String[] { "kw" })
            .tfInputPosition(1)
            .shapePosition(1)
            .onnxAttrName("kernel_shape")
            .build();
    // NOTE(review): the TF Conv2D op names this attribute "dilations", not "rates" - confirm
    // that "rates" is intended (e.g. for an older exporter).
    val dilationMapping = PropertyMapping.builder()
            .onnxAttrName("dilations")
            .propertyNames(new String[] { "dw", "dh" })
            .tfAttrName("rates")
            .build();
    // NOTE(review): ONNX Conv defines no "data_format" attribute, so presumably only the TF
    // name can match for the two data-format properties below.
    val dataFormat = PropertyMapping.builder()
            .onnxAttrName("data_format")
            .tfAttrName("data_format")
            .propertyNames(new String[] { "dataFormat" })
            .build();
    val nhwc = PropertyMapping.builder()
            .onnxAttrName("data_format")
            .tfAttrName("data_format")
            .propertyNames(new String[] { "isNHWC" })
            .build();
    val sameMode = PropertyMapping.builder()
            .onnxAttrName("auto_pad")
            .propertyNames(new String[] { "isSameMode" })
            .tfAttrName("padding")
            .build();
    // Explicit padding has no TF attribute here.
    // NOTE(review): ONNX Conv names this attribute "pads", not "padding" - confirm.
    val paddingWidthHeight = PropertyMapping.builder()
            .onnxAttrName("padding")
            .propertyNames(new String[] { "ph", "pw" })
            .build();

    map.put("sx", strideMapping);
    map.put("sy", strideMapping);
    map.put("kh", kernelMappingH);
    map.put("kw", kernelMappingW);
    map.put("dw", dilationMapping);
    map.put("dh", dilationMapping);
    map.put("isSameMode", sameMode);
    map.put("ph", paddingWidthHeight);
    map.put("pw", paddingWidthHeight);
    map.put("dataFormat", dataFormat);
    map.put("isNHWC", nhwc);

    try {
        ret.put(onnxName(), map);
    } catch (NoOpNameFoundException ignored) {
        // Deliberately ignored: without an ONNX op name there is no key for the ONNX mappings,
        // so they are skipped and the TensorFlow registration below still runs.
    }
    try {
        ret.put(tensorflowName(), map);
    } catch (NoOpNameFoundException ignored) {
        // Deliberately ignored: without a TensorFlow op name the TF mappings are skipped and
        // whatever was registered above is still returned.
    }
    return ret;
}
Also used: lombok.val (lombok.val), PropertyMapping (org.nd4j.imports.descriptors.properties.PropertyMapping), NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException)

Example 7 with NoOpNameFoundException

Use of org.nd4j.imports.NoOpNameFoundException in project nd4j by deeplearning4j.

Found in class DepthwiseConv2D, method mappingsForFunction.

/**
 * Returns the property mappings used when importing this depthwise convolution from ONNX or
 * TensorFlow.
 *
 * <p>Each entry of the inner map associates a property name of this op (for example
 * {@code "sx"} or {@code "kh"}) with the {@link PropertyMapping} that says which ONNX attribute,
 * TF attribute, or TF input-shape dimension the property is read from. The same inner map is
 * registered under both the ONNX op name and the TensorFlow op name.
 *
 * @return map keyed by framework op name; a framework is omitted when looking up its op name
 *     throws {@link NoOpNameFoundException}
 */
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
    Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
    Map<String, PropertyMapping> map = new HashMap<>();

    // A single mapping feeds both stride properties.
    val strideMapping = PropertyMapping.builder()
            .tfAttrName("strides")
            .onnxAttrName("strides")
            .propertyNames(new String[] { "sx", "sy" })
            .build();
    // Kernel height/width come from dimensions 0 and 1 of the shape of TF input 1
    // (presumably the depthwise filter tensor - confirm); ONNX supplies them via "kernel_shape".
    val kernelMappingH = PropertyMapping.builder()
            .propertyNames(new String[] { "kh" })
            .tfInputPosition(1)
            .shapePosition(0)
            .onnxAttrName("kernel_shape")
            .build();
    val kernelMappingW = PropertyMapping.builder()
            .propertyNames(new String[] { "kw" })
            .tfInputPosition(1)
            .shapePosition(1)
            .onnxAttrName("kernel_shape")
            .build();
    // NOTE(review): the TF DepthwiseConv2dNative op names this attribute "dilations", not
    // "rates" - confirm that "rates" is intended.
    val dilationMapping = PropertyMapping.builder()
            .onnxAttrName("dilations")
            .propertyNames(new String[] { "dw", "dh" })
            .tfAttrName("rates")
            .build();
    // NOTE(review): ONNX Conv defines no "data_format" attribute, so presumably only the TF
    // name can match for the two data-format properties below.
    val dataFormat = PropertyMapping.builder()
            .onnxAttrName("data_format")
            .tfAttrName("data_format")
            .propertyNames(new String[] { "dataFormat" })
            .build();
    val nhwc = PropertyMapping.builder()
            .onnxAttrName("data_format")
            .tfAttrName("data_format")
            .propertyNames(new String[] { "isNHWC" })
            .build();
    val sameMode = PropertyMapping.builder()
            .onnxAttrName("auto_pad")
            .propertyNames(new String[] { "isSameMode" })
            .tfAttrName("padding")
            .build();
    // Explicit padding has no TF attribute here.
    // NOTE(review): ONNX Conv names this attribute "pads", not "padding" - confirm.
    val paddingWidthHeight = PropertyMapping.builder()
            .onnxAttrName("padding")
            .propertyNames(new String[] { "ph", "pw" })
            .build();

    map.put("sx", strideMapping);
    map.put("sy", strideMapping);
    map.put("kh", kernelMappingH);
    map.put("kw", kernelMappingW);
    map.put("dw", dilationMapping);
    map.put("dh", dilationMapping);
    map.put("isSameMode", sameMode);
    map.put("ph", paddingWidthHeight);
    map.put("pw", paddingWidthHeight);
    map.put("dataFormat", dataFormat);
    map.put("isNHWC", nhwc);

    try {
        ret.put(onnxName(), map);
    } catch (NoOpNameFoundException ignored) {
        // Deliberately ignored: without an ONNX op name there is no key for the ONNX mappings,
        // so they are skipped and the TensorFlow registration below still runs.
    }
    try {
        ret.put(tensorflowName(), map);
    } catch (NoOpNameFoundException ignored) {
        // Deliberately ignored: without a TensorFlow op name the TF mappings are skipped and
        // whatever was registered above is still returned.
    }
    return ret;
}
Also used: lombok.val (lombok.val), PropertyMapping (org.nd4j.imports.descriptors.properties.PropertyMapping), NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException)

Example 8 with NoOpNameFoundException

Use of org.nd4j.imports.NoOpNameFoundException in project nd4j by deeplearning4j.

Found in class DepthwiseConv2D, method attributeAdaptersForFunction.

/**
 * Returns the attribute adapters that convert raw imported attribute values into this op's
 * properties, keyed by framework op name.
 *
 * <p>Separate adapter maps are registered for TensorFlow and ONNX. The TensorFlow adapters for
 * kernel size and strides are conditioned on this op's {@code dataFormat} field.
 *
 * @return map keyed by framework op name; a framework is omitted when looking up its op name
 *     throws {@link NoOpNameFoundException}
 */
@Override
public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
    Map<String, Map<String, AttributeAdapter>> ret = new HashMap<>();
    val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);

    // --- TensorFlow ---
    Map<String, AttributeAdapter> tfMappings = new LinkedHashMap<>();
    // Kernel height/width: each adapter takes two shape indices chosen by comparing the
    // dataFormat field against "NCHW". Both indices are equal here, so presumably the result
    // does not depend on the data format - confirm against the adapter.
    tfMappings.put("kh", new ConditionalFieldValueNDArrayShapeAdapter("NCHW", 0, 0, fields.get("dataFormat")));
    tfMappings.put("kw", new ConditionalFieldValueNDArrayShapeAdapter("NCHW", 1, 1, fields.get("dataFormat")));
    // Strides: presumably indices 2/3 are read when dataFormat is "NCHW" and 1/2 otherwise,
    // i.e. where H and W sit in a 4-element strides list - confirm against the adapter.
    tfMappings.put("sy", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 2, 1, fields.get("dataFormat")));
    tfMappings.put("sx", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 3, 2, fields.get("dataFormat")));
    // NOTE(review): no TF adapter is registered for "dh"/"dw"; confirm how dilation is imported.
    tfMappings.put("isSameMode", new StringEqualsAdapter("SAME"));
    tfMappings.put("isNHWC", new StringEqualsAdapter("NHWC"));

    // --- ONNX ---
    // Each SizeThresholdIntArrayIntIndexAdpater(index, 2, 0) presumably reads element `index`
    // when the array has at least 2 entries and otherwise falls back to 0 - confirm.
    Map<String, AttributeAdapter> onnxMappings = new HashMap<>();
    onnxMappings.put("kh", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxMappings.put("kw", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxMappings.put("dh", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxMappings.put("dw", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxMappings.put("sy", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxMappings.put("sx", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxMappings.put("isSameMode", new StringEqualsAdapter("SAME"));
    onnxMappings.put("isNHWC", new StringEqualsAdapter("NHWC"));

    try {
        ret.put(tensorflowName(), tfMappings);
    } catch (NoOpNameFoundException ignored) {
        // Deliberately ignored: without a TensorFlow op name the TF adapters are skipped and the
        // ONNX registration below still runs.
    }
    try {
        ret.put(onnxName(), onnxMappings);
    } catch (NoOpNameFoundException ignored) {
        // Deliberately ignored: without an ONNX op name the ONNX adapters are skipped and
        // whatever was registered above is still returned.
    }
    return ret;
}
Also used: lombok.val (lombok.val), ConditionalFieldValueNDArrayShapeAdapter (org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueNDArrayShapeAdapter), StringEqualsAdapter (org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter), SizeThresholdIntArrayIntIndexAdpater (org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater), AttributeAdapter (org.nd4j.imports.descriptors.properties.AttributeAdapter), NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException), ConditionalFieldValueIntIndexArrayAdapter (org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter)

Example 9 with NoOpNameFoundException

Use of org.nd4j.imports.NoOpNameFoundException in project nd4j by deeplearning4j.

Found in class Conv1D, method mappingsForFunction.

/**
 * Returns the property mappings used when importing this 1-D convolution from ONNX or
 * TensorFlow.
 *
 * <p>Each entry of the inner map associates a property name of this op (stride {@code "s"},
 * kernel {@code "k"}, padding {@code "p"}, and the same-mode/data-format flags) with the
 * {@link PropertyMapping} describing where it is read from. The same inner map is registered
 * under both the ONNX op name and the TensorFlow op name.
 *
 * @return map keyed by framework op name; a framework is omitted when looking up its op name
 *     throws {@link NoOpNameFoundException}
 */
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
    Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
    Map<String, PropertyMapping> map = new HashMap<>();

    val strideMapping = PropertyMapping.builder()
            .tfAttrName("strides")
            .onnxAttrName("strides")
            .propertyNames(new String[] { "s" })
            .build();
    // Kernel size comes from dimension 0 of the shape of TF input 1 (presumably the weights
    // tensor - confirm); ONNX supplies it via "kernel_shape".
    val kernelMapping = PropertyMapping.builder()
            .propertyNames(new String[] { "k" })
            .tfInputPosition(1)
            .shapePosition(0)
            .onnxAttrName("kernel_shape")
            .build();
    // Explicit padding has no TF attribute here.
    // NOTE(review): ONNX Conv names this attribute "pads", not "padding" - confirm.
    val paddingMapping = PropertyMapping.builder()
            .onnxAttrName("padding")
            .propertyNames(new String[] { "p" })
            .build();
    // NOTE(review): ONNX Conv defines no "data_format" attribute, so presumably only the TF
    // name can match for the two data-format properties below.
    val dataFormat = PropertyMapping.builder()
            .onnxAttrName("data_format")
            .tfAttrName("data_format")
            .propertyNames(new String[] { "dataFormat" })
            .build();
    val nhwc = PropertyMapping.builder()
            .onnxAttrName("data_format")
            .tfAttrName("data_format")
            .propertyNames(new String[] { "isNHWC" })
            .build();
    val sameMode = PropertyMapping.builder()
            .onnxAttrName("auto_pad")
            .propertyNames(new String[] { "isSameMode" })
            .tfAttrName("padding")
            .build();

    map.put("s", strideMapping);
    map.put("k", kernelMapping);
    map.put("p", paddingMapping);
    map.put("isSameMode", sameMode);
    map.put("dataFormat", dataFormat);
    map.put("isNHWC", nhwc);

    try {
        ret.put(onnxName(), map);
    } catch (NoOpNameFoundException ignored) {
        // Deliberately ignored: without an ONNX op name there is no key for the ONNX mappings,
        // so they are skipped and the TensorFlow registration below still runs.
    }
    try {
        ret.put(tensorflowName(), map);
    } catch (NoOpNameFoundException ignored) {
        // Deliberately ignored: without a TensorFlow op name the TF mappings are skipped and
        // whatever was registered above is still returned.
    }
    return ret;
}
Also used: lombok.val (lombok.val), PropertyMapping (org.nd4j.imports.descriptors.properties.PropertyMapping), NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException)

Aggregations

Usage counts: NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException): 9; lombok.val (lombok.val): 8; PropertyMapping (org.nd4j.imports.descriptors.properties.PropertyMapping): 5; HashMap (java.util.HashMap): 2; Map (java.util.Map): 2; DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction): 2; ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 2; LinkedHashMap (java.util.LinkedHashMap): 1; SDVariable (org.nd4j.autodiff.samediff.SDVariable): 1; AttributeAdapter (org.nd4j.imports.descriptors.properties.AttributeAdapter): 1; ConditionalFieldValueIntIndexArrayAdapter (org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter): 1; ConditionalFieldValueNDArrayShapeAdapter (org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueNDArrayShapeAdapter): 1; SizeThresholdIntArrayIntIndexAdpater (org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater): 1; StringEqualsAdapter (org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter): 1; Op (org.nd4j.linalg.api.ops.Op): 1