use of org.nd4j.imports.descriptors.properties.PropertyMapping in project nd4j by deeplearning4j.
Class MaxPooling2D, method mappingsForFunction.
/**
 * Builds the property mappings used when importing this pooling op.
 *
 * <p>Each key of the inner map is a property name, and its value is the
 * {@link PropertyMapping} naming the TensorFlow / ONNX attribute (or input
 * position) configured for that property. The same inner map instance is
 * registered under both {@code onnxName()} and {@code tensorflowName()}.
 *
 * @return a map from framework op name to (property name -> mapping)
 */
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
    Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
    Map<String, PropertyMapping> map = new HashMap<>();

    val strideMapping = PropertyMapping.builder()
            .tfAttrName("strides")
            .onnxAttrName("strides")
            .propertyNames(new String[] { "sx", "sy" })
            .build();
    // Fix: the mapping previously declared "px"/"py" while being registered
    // under "ph"/"pw" below. Every other mapping here declares exactly the
    // names it is registered under, so the declared names now match the keys.
    // NOTE(review): assumes "ph"/"pw" are the real property names -- confirm
    // against the pooling config class.
    val paddingMapping = PropertyMapping.builder()
            .onnxAttrName("padding")
            .tfAttrName("padding")
            .propertyNames(new String[] { "ph", "pw" })
            .build();
    // The kernel size is configured with a TF input position (1) rather than
    // a TF attribute name.
    val kernelMapping = PropertyMapping.builder()
            .propertyNames(new String[] { "kh", "kw" })
            .tfInputPosition(1)
            .onnxAttrName("ksize")
            .build();
    val dilationMapping = PropertyMapping.builder()
            .onnxAttrName("dilations")
            .propertyNames(new String[] { "dw", "dh" })
            .tfAttrName("rates")
            .build();
    // data_format: only a TensorFlow attribute name is configured for it.
    val dataFormatMapping = PropertyMapping.builder()
            .propertyNames(new String[] { "isNHWC" })
            .tfAttrName("data_format")
            .build();

    map.put("sx", strideMapping);
    map.put("sy", strideMapping);
    map.put("kh", kernelMapping);
    map.put("kw", kernelMapping);
    map.put("dw", dilationMapping);
    map.put("dh", dilationMapping);
    map.put("ph", paddingMapping);
    map.put("pw", paddingMapping);
    map.put("isNHWC", dataFormatMapping);

    ret.put(onnxName(), map);
    ret.put(tensorflowName(), map);
    return ret;
}
use of org.nd4j.imports.descriptors.properties.PropertyMapping in project nd4j by deeplearning4j.
Class AvgPooling2D, method mappingsForFunction.
/**
 * Builds the property mappings used when importing this pooling op.
 *
 * <p>Each key of the inner map is a property name, and its value is the
 * {@link PropertyMapping} naming the TensorFlow / ONNX attribute (or input
 * position) configured for that property. The same inner map instance is
 * registered under both {@code onnxName()} and {@code tensorflowName()}.
 *
 * @return a map from framework op name to (property name -> mapping)
 */
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
    Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
    Map<String, PropertyMapping> map = new HashMap<>();

    val strideMapping = PropertyMapping.builder()
            .tfAttrName("strides")
            .onnxAttrName("strides")
            .propertyNames(new String[] { "sx", "sy" })
            .build();
    // Fix: the mapping previously declared "px"/"py" while being registered
    // under "ph"/"pw" below. Every other mapping here declares exactly the
    // names it is registered under, so the declared names now match the keys.
    // NOTE(review): assumes "ph"/"pw" are the real property names -- confirm
    // against the pooling config class.
    val paddingMapping = PropertyMapping.builder()
            .onnxAttrName("padding")
            .tfAttrName("padding")
            .propertyNames(new String[] { "ph", "pw" })
            .build();
    // The kernel size is configured with a TF input position (1) rather than
    // a TF attribute name.
    val kernelMapping = PropertyMapping.builder()
            .propertyNames(new String[] { "kh", "kw" })
            .tfInputPosition(1)
            .onnxAttrName("ksize")
            .build();
    val dilationMapping = PropertyMapping.builder()
            .onnxAttrName("dilations")
            .propertyNames(new String[] { "dw", "dh" })
            .tfAttrName("rates")
            .build();
    // data_format: only a TensorFlow attribute name is configured for it.
    val dataFormatMapping = PropertyMapping.builder()
            .propertyNames(new String[] { "isNHWC" })
            .tfAttrName("data_format")
            .build();

    map.put("sx", strideMapping);
    map.put("sy", strideMapping);
    map.put("kh", kernelMapping);
    map.put("kw", kernelMapping);
    map.put("dw", dilationMapping);
    map.put("dh", dilationMapping);
    map.put("ph", paddingMapping);
    map.put("pw", paddingMapping);
    map.put("isNHWC", dataFormatMapping);

    ret.put(onnxName(), map);
    ret.put(tensorflowName(), map);
    return ret;
}
use of org.nd4j.imports.descriptors.properties.PropertyMapping in project nd4j by deeplearning4j.
Class BaseCompatOp, method mappingsForFunction.
/**
 * Builds the property mappings for this op: a single {@code frameName}
 * property tied to the {@code frame_name} attribute.
 *
 * <p>The mapping is registered under {@code onnxName()} and
 * {@code tensorflowName()} independently; a name lookup that throws
 * {@link NoOpNameFoundException} is skipped, so the result may hold zero,
 * one, or two entries.
 *
 * @return a map from framework op name to (property name -> mapping)
 */
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
    // The ONNX attribute name mirrors the TensorFlow one; it is not certain
    // that such an attribute exists in ONNX.
    final PropertyMapping frameName = PropertyMapping.builder()
            .tfAttrName("frame_name")
            .onnxAttrName("frame_name")
            .propertyNames(new String[] { "frameName" })
            .build();

    final Map<String, PropertyMapping> propertiesByName = new HashMap<>();
    propertiesByName.put("frameName", frameName);

    final Map<String, Map<String, PropertyMapping>> mappingsByOpName = new HashMap<>();
    try {
        mappingsByOpName.put(onnxName(), propertiesByName);
    } catch (NoOpNameFoundException ignored) {
        // Deliberately skipped: ONNX is not a concern for this set of ops.
    }
    try {
        mappingsByOpName.put(tensorflowName(), propertiesByName);
    } catch (NoOpNameFoundException ignored) {
        // Deliberately skipped: no TensorFlow name, so nothing to register.
    }
    return mappingsByOpName;
}
use of org.nd4j.imports.descriptors.properties.PropertyMapping in project nd4j by deeplearning4j.
Class CumProd, method mappingsForFunction.
/**
 * Builds the property mappings for this op. Only a TensorFlow entry is
 * registered; it ties the {@code exclusive} and {@code reverse} properties
 * to the TensorFlow attributes of the same names.
 *
 * @return a map from {@code tensorflowName()} to (property name -> mapping)
 */
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
    final Map<String, PropertyMapping> attributeMappings = new HashMap<>();
    attributeMappings.put("exclusive", PropertyMapping.builder()
            .tfAttrName("exclusive")
            .propertyNames(new String[] { "exclusive" })
            .build());
    attributeMappings.put("reverse", PropertyMapping.builder()
            .tfAttrName("reverse")
            .propertyNames(new String[] { "reverse" })
            .build());

    final Map<String, Map<String, PropertyMapping>> mappingsByOpName = new HashMap<>();
    mappingsByOpName.put(tensorflowName(), attributeMappings);
    return mappingsByOpName;
}
use of org.nd4j.imports.descriptors.properties.PropertyMapping in project nd4j by deeplearning4j.
Class Split, method mappingsForFunction.
/**
 * Builds the property mappings for this op. Only a TensorFlow entry is
 * registered: {@code splitDim} is configured with TF input position 0 and
 * {@code numSplit} with the {@code num_split} attribute.
 *
 * @return a map from {@code tensorflowName()} to (property name -> mapping)
 */
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
    final PropertyMapping splitDimMapping = PropertyMapping.builder()
            .tfInputPosition(0)
            .propertyNames(new String[] { "splitDim" })
            .build();
    final PropertyMapping numSplitMapping = PropertyMapping.builder()
            .tfAttrName("num_split")
            .propertyNames(new String[] { "numSplit" })
            .build();

    final Map<String, PropertyMapping> propertiesByName = new HashMap<>();
    propertiesByName.put("numSplit", numSplitMapping);
    propertiesByName.put("splitDim", splitDimMapping);

    final Map<String, Map<String, PropertyMapping>> mappingsByOpName = new HashMap<>();
    mappingsByOpName.put(tensorflowName(), propertiesByName);
    return mappingsByOpName;
}
Aggregations