Search in sources :

Example 1 with SizeThresholdIntArrayIntIndexAdpater

use of org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater in project nd4j by deeplearning4j.

The method attributeAdaptersForFunction of the class Conv2D:

/**
 * Builds the attribute adapters for this op, grouped by import framework.
 *
 * @return a map keyed by {@code tensorflowName()} and {@code onnxName()}, each pointing at
 *         that framework's property-name to {@link AttributeAdapter} mapping
 */
@Override
public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
    val functionFields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);

    // TensorFlow: kernel and stride adapters are conditional on the "dataFormat" field.
    // NOTE(review): presumably the first index applies when dataFormat is "NCHW" and the
    // second otherwise - confirm against the adapter implementations.
    Map<String, AttributeAdapter> tensorflowAdapters = new LinkedHashMap<>();
    tensorflowAdapters.put("kh", new ConditionalFieldValueNDArrayShapeAdapter("NCHW", 2, 0, functionFields.get("dataFormat")));
    tensorflowAdapters.put("kw", new ConditionalFieldValueNDArrayShapeAdapter("NCHW", 3, 1, functionFields.get("dataFormat")));
    tensorflowAdapters.put("sy", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 2, 1, functionFields.get("dataFormat")));
    tensorflowAdapters.put("sx", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 3, 2, functionFields.get("dataFormat")));
    tensorflowAdapters.put("isSameMode", new StringEqualsAdapter("SAME"));
    tensorflowAdapters.put("isNHWC", new StringEqualsAdapter("NHWC"));

    // ONNX: kernel (k*), dilation (d*) and stride (s*) all use the same index/threshold/default
    // arguments; the h/y entries use index 0 and the w/x entries use index 1.
    Map<String, AttributeAdapter> onnxAdapters = new HashMap<>();
    onnxAdapters.put("kh", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxAdapters.put("kw", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxAdapters.put("dh", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxAdapters.put("dw", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxAdapters.put("sy", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxAdapters.put("sx", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxAdapters.put("isSameMode", new StringEqualsAdapter("SAME"));
    onnxAdapters.put("isNHWC", new StringEqualsAdapter("NHWC"));

    Map<String, Map<String, AttributeAdapter>> adaptersByFramework = new HashMap<>();
    adaptersByFramework.put(tensorflowName(), tensorflowAdapters);
    adaptersByFramework.put(onnxName(), onnxAdapters);
    return adaptersByFramework;
}
Also used : lombok.val(lombok.val) ConditionalFieldValueNDArrayShapeAdapter(org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueNDArrayShapeAdapter) SizeThresholdIntArrayIntIndexAdpater(org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater) AttributeAdapter(org.nd4j.imports.descriptors.properties.AttributeAdapter) StringEqualsAdapter(org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter) ConditionalFieldValueIntIndexArrayAdapter(org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter)

Example 2 with SizeThresholdIntArrayIntIndexAdpater

use of org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater in project nd4j by deeplearning4j.

The method attributeAdaptersForFunction of the class Conv1D:

/**
 * Returns the attribute adapters for this op, one mapping per import framework.
 *
 * @return a map whose keys are the results of {@code tensorflowName()} and
 *         {@code onnxName()} and whose values map property names to their adapters
 */
@Override
public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
    val declaredFields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);

    // TensorFlow side; insertion order is kept by the LinkedHashMap.
    // NOTE(review): these are the same 2-D style kh/kw/sy/sx entries used elsewhere for
    // 2-D convolutions - confirm they are intended for a 1-D op.
    Map<String, AttributeAdapter> tfAdapters = new LinkedHashMap<>();
    tfAdapters.put("kh", new ConditionalFieldValueNDArrayShapeAdapter("NCHW", 2, 0, declaredFields.get("dataFormat")));
    tfAdapters.put("kw", new ConditionalFieldValueNDArrayShapeAdapter("NCHW", 3, 1, declaredFields.get("dataFormat")));
    tfAdapters.put("sy", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 2, 1, declaredFields.get("dataFormat")));
    tfAdapters.put("sx", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 3, 2, declaredFields.get("dataFormat")));
    tfAdapters.put("isSameMode", new StringEqualsAdapter("SAME"));
    tfAdapters.put("isNHWC", new StringEqualsAdapter("NHWC"));

    // ONNX side; every size adapter is built with (index, 2, 0), index 0 for h/y and 1 for w/x.
    Map<String, AttributeAdapter> onnxAdapters = new HashMap<>();
    onnxAdapters.put("kh", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxAdapters.put("kw", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxAdapters.put("dh", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxAdapters.put("dw", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxAdapters.put("sy", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxAdapters.put("sx", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxAdapters.put("isSameMode", new StringEqualsAdapter("SAME"));
    onnxAdapters.put("isNHWC", new StringEqualsAdapter("NHWC"));

    Map<String, Map<String, AttributeAdapter>> result = new HashMap<>();
    result.put(tensorflowName(), tfAdapters);
    result.put(onnxName(), onnxAdapters);
    return result;
}
Also used : lombok.val(lombok.val) ConditionalFieldValueNDArrayShapeAdapter(org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueNDArrayShapeAdapter) SizeThresholdIntArrayIntIndexAdpater(org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater) AttributeAdapter(org.nd4j.imports.descriptors.properties.AttributeAdapter) StringEqualsAdapter(org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter) ConditionalFieldValueIntIndexArrayAdapter(org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter)

Example 3 with SizeThresholdIntArrayIntIndexAdpater

use of org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater in project nd4j by deeplearning4j.

The method attributeAdaptersForFunction of the class DepthwiseConv2D:

/**
 * Builds the attribute adapters for this op, grouped by import framework.
 *
 * <p>A framework's mapping is left out of the result when looking up its op name throws
 * {@link NoOpNameFoundException}.
 *
 * @return a map from framework op name to that framework's property-name to adapter
 *         mapping; it may contain zero, one or two entries
 */
@Override
public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
    val functionFields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);

    // TensorFlow: the kernel adapters use the same index (0 for kh, 1 for kw) in both
    // positions, while the stride adapters use different indices per data format.
    Map<String, AttributeAdapter> tensorflowAdapters = new LinkedHashMap<>();
    tensorflowAdapters.put("kh", new ConditionalFieldValueNDArrayShapeAdapter("NCHW", 0, 0, functionFields.get("dataFormat")));
    tensorflowAdapters.put("kw", new ConditionalFieldValueNDArrayShapeAdapter("NCHW", 1, 1, functionFields.get("dataFormat")));
    tensorflowAdapters.put("sy", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 2, 1, functionFields.get("dataFormat")));
    tensorflowAdapters.put("sx", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 3, 2, functionFields.get("dataFormat")));
    tensorflowAdapters.put("isSameMode", new StringEqualsAdapter("SAME"));
    tensorflowAdapters.put("isNHWC", new StringEqualsAdapter("NHWC"));

    // ONNX: kernel, dilation and stride entries all use (index, 2, 0).
    Map<String, AttributeAdapter> onnxAdapters = new HashMap<>();
    onnxAdapters.put("kh", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxAdapters.put("kw", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxAdapters.put("dh", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxAdapters.put("dw", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxAdapters.put("sy", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxAdapters.put("sx", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxAdapters.put("isSameMode", new StringEqualsAdapter("SAME"));
    onnxAdapters.put("isNHWC", new StringEqualsAdapter("NHWC"));

    Map<String, Map<String, AttributeAdapter>> adaptersByFramework = new HashMap<>();
    try {
        adaptersByFramework.put(tensorflowName(), tensorflowAdapters);
    } catch (NoOpNameFoundException ignored) {
        // Deliberately skipped: presumably this op has no TensorFlow name - confirm.
    }
    try {
        adaptersByFramework.put(onnxName(), onnxAdapters);
    } catch (NoOpNameFoundException ignored) {
        // Deliberately skipped: presumably this op has no ONNX name - confirm.
    }
    return adaptersByFramework;
}
Also used : lombok.val(lombok.val) ConditionalFieldValueNDArrayShapeAdapter(org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueNDArrayShapeAdapter) StringEqualsAdapter(org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter) SizeThresholdIntArrayIntIndexAdpater(org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater) AttributeAdapter(org.nd4j.imports.descriptors.properties.AttributeAdapter) NoOpNameFoundException(org.nd4j.imports.NoOpNameFoundException) ConditionalFieldValueIntIndexArrayAdapter(org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter)

Aggregations

lombok.val (lombok.val)3 AttributeAdapter (org.nd4j.imports.descriptors.properties.AttributeAdapter)3 ConditionalFieldValueIntIndexArrayAdapter (org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter)3 ConditionalFieldValueNDArrayShapeAdapter (org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueNDArrayShapeAdapter)3 SizeThresholdIntArrayIntIndexAdpater (org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater)3 StringEqualsAdapter (org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter)3 NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException)1