Usage of org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter in the nd4j project (deeplearning4j).
Class Conv2D, method attributeAdaptersForFunction:
/**
 * Builds the attribute adapters used when importing this op from TensorFlow and ONNX.
 *
 * <p>The TensorFlow adapters are all conditioned on this op's {@code dataFormat} field and the
 * literal {@code "NCHW"}. Each one receives two candidate indices; presumably the first is used
 * when the field is NCHW and the second otherwise (TODO confirm against the adapter classes).
 * The ONNX adapters read kernel, dilation and stride values by index via
 * {@code SizeThresholdIntArrayIntIndexAdpater}. Both frameworks map {@code isSameMode} and
 * {@code isNHWC} through string-equality adapters.
 *
 * @return a map with one entry keyed by {@code tensorflowName()} and one keyed by
 *     {@code onnxName()}; each value maps a field name of this op to its adapter
 */
@Override
public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
    // Every TF adapter below depends on the same "dataFormat" field, so resolve it once.
    val dataFormatField = DifferentialFunctionClassHolder.getInstance()
            .getFieldsForFunction(this)
            .get("dataFormat");

    // Insertion-ordered, so the TF adapters iterate in the order they are registered here.
    Map<String, AttributeAdapter> tensorflowAdapters = new LinkedHashMap<>();
    tensorflowAdapters.put("kh", new ConditionalFieldValueNDArrayShapeAdapter("NCHW", 2, 0, dataFormatField));
    tensorflowAdapters.put("kw", new ConditionalFieldValueNDArrayShapeAdapter("NCHW", 3, 1, dataFormatField));
    tensorflowAdapters.put("sy", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 2, 1, dataFormatField));
    tensorflowAdapters.put("sx", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 3, 2, dataFormatField));
    tensorflowAdapters.put("isSameMode", new StringEqualsAdapter("SAME"));
    tensorflowAdapters.put("isNHWC", new StringEqualsAdapter("NHWC"));

    // Height-like fields read index 0, width-like fields read index 1.
    Map<String, AttributeAdapter> onnxAdapters = new HashMap<>();
    onnxAdapters.put("kh", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxAdapters.put("kw", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxAdapters.put("dh", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxAdapters.put("dw", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxAdapters.put("sy", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxAdapters.put("sx", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    // NOTE(review): ONNX Conv's auto_pad values are SAME_UPPER / SAME_LOWER / VALID / NOTSET;
    // confirm which ONNX attribute feeds isSameMode and that an exact "SAME" match is intended.
    onnxAdapters.put("isSameMode", new StringEqualsAdapter("SAME"));
    onnxAdapters.put("isNHWC", new StringEqualsAdapter("NHWC"));

    Map<String, Map<String, AttributeAdapter>> adaptersByFramework = new HashMap<>();
    adaptersByFramework.put(tensorflowName(), tensorflowAdapters);
    adaptersByFramework.put(onnxName(), onnxAdapters);
    return adaptersByFramework;
}
Usage of org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter in the nd4j project (deeplearning4j).
Class Conv1D, method attributeAdaptersForFunction:
/**
 * Builds the attribute adapters used when importing this op from TensorFlow and ONNX.
 *
 * <p>The TensorFlow adapters are all conditioned on this op's {@code dataFormat} field and the
 * literal {@code "NCHW"}. Each one receives two candidate indices; presumably the first is used
 * when the field is NCHW and the second otherwise (TODO confirm against the adapter classes).
 * The ONNX adapters read kernel, dilation and stride values by index via
 * {@code SizeThresholdIntArrayIntIndexAdpater}. Both frameworks map {@code isSameMode} and
 * {@code isNHWC} through string-equality adapters.
 *
 * @return a map with one entry keyed by {@code tensorflowName()} and one keyed by
 *     {@code onnxName()}; each value maps a field name of this op to its adapter
 */
@Override
public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
    // Every TF adapter below depends on the same "dataFormat" field, so resolve it once.
    val dataFormatField = DifferentialFunctionClassHolder.getInstance()
            .getFieldsForFunction(this)
            .get("dataFormat");

    // NOTE(review): indices up to 3 imply rank >= 4 shape/stride arrays; confirm this matches
    // how this 1-D convolution's weights and strides are represented on import.
    // Insertion-ordered, so the TF adapters iterate in the order they are registered here.
    Map<String, AttributeAdapter> tensorflowAdapters = new LinkedHashMap<>();
    tensorflowAdapters.put("kh", new ConditionalFieldValueNDArrayShapeAdapter("NCHW", 2, 0, dataFormatField));
    tensorflowAdapters.put("kw", new ConditionalFieldValueNDArrayShapeAdapter("NCHW", 3, 1, dataFormatField));
    tensorflowAdapters.put("sy", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 2, 1, dataFormatField));
    tensorflowAdapters.put("sx", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 3, 2, dataFormatField));
    tensorflowAdapters.put("isSameMode", new StringEqualsAdapter("SAME"));
    tensorflowAdapters.put("isNHWC", new StringEqualsAdapter("NHWC"));

    // Height-like fields read index 0, width-like fields read index 1.
    Map<String, AttributeAdapter> onnxAdapters = new HashMap<>();
    onnxAdapters.put("kh", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxAdapters.put("kw", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxAdapters.put("dh", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxAdapters.put("dw", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxAdapters.put("sy", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxAdapters.put("sx", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    // NOTE(review): ONNX Conv's auto_pad values are SAME_UPPER / SAME_LOWER / VALID / NOTSET;
    // confirm which ONNX attribute feeds isSameMode and that an exact "SAME" match is intended.
    onnxAdapters.put("isSameMode", new StringEqualsAdapter("SAME"));
    onnxAdapters.put("isNHWC", new StringEqualsAdapter("NHWC"));

    Map<String, Map<String, AttributeAdapter>> adaptersByFramework = new HashMap<>();
    adaptersByFramework.put(tensorflowName(), tensorflowAdapters);
    adaptersByFramework.put(onnxName(), onnxAdapters);
    return adaptersByFramework;
}
Usage of org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter in the nd4j project (deeplearning4j).
Class DepthwiseConv2D, method attributeAdaptersForFunction:
/**
 * Builds the attribute adapters used when importing this op from TensorFlow and ONNX.
 *
 * <p>The TensorFlow adapters are all conditioned on this op's {@code dataFormat} field and the
 * literal {@code "NCHW"}. The kernel-size adapters pass the same index for both layouts
 * (0 for {@code kh}, 1 for {@code kw}); the stride adapters pass layout-dependent indices.
 * The ONNX adapters read kernel, dilation and stride values by index via
 * {@code SizeThresholdIntArrayIntIndexAdpater}.
 *
 * <p>Each framework is registered independently. If resolving a framework's op name throws
 * {@code NoOpNameFoundException}, that framework's entry is omitted and the method carries on.
 *
 * @return a map keyed by whichever of {@code tensorflowName()} / {@code onnxName()} resolved;
 *     each value maps a field name of this op to its adapter
 */
@Override
public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
    // Every TF adapter below depends on the same "dataFormat" field, so resolve it once.
    val dataFormatField = DifferentialFunctionClassHolder.getInstance()
            .getFieldsForFunction(this)
            .get("dataFormat");

    // Insertion-ordered, so the TF adapters iterate in the order they are registered here.
    Map<String, AttributeAdapter> tensorflowAdapters = new LinkedHashMap<>();
    tensorflowAdapters.put("kh", new ConditionalFieldValueNDArrayShapeAdapter("NCHW", 0, 0, dataFormatField));
    tensorflowAdapters.put("kw", new ConditionalFieldValueNDArrayShapeAdapter("NCHW", 1, 1, dataFormatField));
    tensorflowAdapters.put("sy", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 2, 1, dataFormatField));
    tensorflowAdapters.put("sx", new ConditionalFieldValueIntIndexArrayAdapter("NCHW", 3, 2, dataFormatField));
    tensorflowAdapters.put("isSameMode", new StringEqualsAdapter("SAME"));
    tensorflowAdapters.put("isNHWC", new StringEqualsAdapter("NHWC"));

    // Height-like fields read index 0, width-like fields read index 1.
    Map<String, AttributeAdapter> onnxAdapters = new HashMap<>();
    onnxAdapters.put("kh", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxAdapters.put("kw", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxAdapters.put("dh", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxAdapters.put("dw", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    onnxAdapters.put("sy", new SizeThresholdIntArrayIntIndexAdpater(0, 2, 0));
    onnxAdapters.put("sx", new SizeThresholdIntArrayIntIndexAdpater(1, 2, 0));
    // NOTE(review): ONNX Conv's auto_pad values are SAME_UPPER / SAME_LOWER / VALID / NOTSET;
    // confirm which ONNX attribute feeds isSameMode and that an exact "SAME" match is intended.
    onnxAdapters.put("isSameMode", new StringEqualsAdapter("SAME"));
    onnxAdapters.put("isNHWC", new StringEqualsAdapter("NHWC"));

    Map<String, Map<String, AttributeAdapter>> adaptersByFramework = new HashMap<>();
    try {
        adaptersByFramework.put(tensorflowName(), tensorflowAdapters);
    } catch (NoOpNameFoundException ignored) {
        // Deliberately best-effort: presumably no TensorFlow op name exists for this op,
        // so the TensorFlow entry is simply left out.
    }
    try {
        adaptersByFramework.put(onnxName(), onnxAdapters);
    } catch (NoOpNameFoundException ignored) {
        // Deliberately best-effort: presumably no ONNX op name exists for this op,
        // so the ONNX entry is simply left out.
    }
    return adaptersByFramework;
}
Aggregations