Use of org.nd4j.imports.descriptors.properties.PropertyMapping in project nd4j by deeplearning4j.
The class Histogram, method mappingsForFunction.
/**
 * Describes how this op's properties are populated when it is imported.
 *
 * <p>The returned map is keyed by the value of {@code tensorflowName()}; the nested map is
 * keyed by property name and contains a single entry for {@code numBins}.
 *
 * @return the property mappings for this op
 */
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
    // NOTE(review): -1 presumably addresses the last TF input -- confirm against PropertyMapping.
    PropertyMapping numBinsMapping = PropertyMapping.builder()
            .tfInputPosition(-1)
            .propertyNames(new String[] { "numBins" })
            .build();

    Map<String, PropertyMapping> byProperty = new HashMap<>();
    byProperty.put("numBins", numBinsMapping);

    Map<String, Map<String, PropertyMapping>> byOpName = new HashMap<>();
    byOpName.put(tensorflowName(), byProperty);
    return byOpName;
}
Use of org.nd4j.imports.descriptors.properties.PropertyMapping in project nd4j by deeplearning4j.
The class Repeat, method mappingsForFunction.
/**
 * Describes how this op's {@code axis} property is populated when it is imported.
 *
 * <p>The same nested map instance (property name to mapping) is registered under both
 * {@code tensorflowName()} and {@code onnxName()}.
 *
 * @return the property mappings for this op
 */
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
    // NOTE(review): -1 presumably addresses the last TF input -- confirm against PropertyMapping.
    PropertyMapping axis = PropertyMapping.builder()
            .onnxAttrName("axis")
            .tfInputPosition(-1)
            .propertyNames(new String[] { "axis" })
            .build();

    Map<String, PropertyMapping> byProperty = new HashMap<>();
    byProperty.put("axis", axis);

    // Both framework names share one mapping table.
    Map<String, Map<String, PropertyMapping>> byOpName = new HashMap<>();
    byOpName.put(tensorflowName(), byProperty);
    byOpName.put(onnxName(), byProperty);
    return byOpName;
}
Use of org.nd4j.imports.descriptors.properties.PropertyMapping in project nd4j by deeplearning4j.
The class StridedSlice, method mappingsForFunction.
/**
 * Describes how this op's properties are populated when it is imported.
 *
 * <p>{@code begin}, {@code end} and {@code strides} are mapped from TF input positions 1, 2
 * and 3; the five mask properties are mapped from the TF attributes of the corresponding
 * snake_case name. The nested map is registered under {@code tensorflowName()} only.
 *
 * @return the property mappings for this op
 */
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
    Map<String, PropertyMapping> byProperty = new HashMap<>();

    // Properties taken from input positions.
    byProperty.put("begin", PropertyMapping.builder()
            .tfInputPosition(1)
            .propertyNames(new String[] { "begin" })
            .build());
    byProperty.put("end", PropertyMapping.builder()
            .tfInputPosition(2)
            .propertyNames(new String[] { "end" })
            .build());
    byProperty.put("strides", PropertyMapping.builder()
            .tfInputPosition(3)
            .propertyNames(new String[] { "strides" })
            .build());

    // Properties taken from node attributes.
    byProperty.put("beginMask", PropertyMapping.builder()
            .tfAttrName("begin_mask")
            .propertyNames(new String[] { "beginMask" })
            .build());
    byProperty.put("ellipsisMask", PropertyMapping.builder()
            .tfAttrName("ellipsis_mask")
            .propertyNames(new String[] { "ellipsisMask" })
            .build());
    byProperty.put("endMask", PropertyMapping.builder()
            .tfAttrName("end_mask")
            .propertyNames(new String[] { "endMask" })
            .build());
    byProperty.put("newAxisMask", PropertyMapping.builder()
            .tfAttrName("new_axis_mask")
            .propertyNames(new String[] { "newAxisMask" })
            .build());
    byProperty.put("shrinkAxisMask", PropertyMapping.builder()
            .tfAttrName("shrink_axis_mask")
            .propertyNames(new String[] { "shrinkAxisMask" })
            .build());

    Map<String, Map<String, PropertyMapping>> byOpName = new HashMap<>();
    byOpName.put(tensorflowName(), byProperty);
    return byOpName;
}
Use of org.nd4j.imports.descriptors.properties.PropertyMapping in project nd4j by deeplearning4j.
The class Transpose, method mappingsForFunction.
/**
 * Describes how this op's {@code permuteDims} property is populated when it is imported.
 *
 * <p>The property is mapped from TF input position 1 and from the ONNX attribute
 * {@code perm}. The same nested map instance is registered under {@code tensorflowName()}
 * first and {@code onnxName()} second; both maps preserve insertion order.
 *
 * @return the property mappings for this op
 */
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
    PropertyMapping permuteDims = PropertyMapping.builder()
            .onnxAttrName("perm")
            .propertyNames(new String[] { "permuteDims" })
            .tfInputPosition(1)
            .build();

    Map<String, PropertyMapping> byProperty = new LinkedHashMap<>();
    byProperty.put("permuteDims", permuteDims);

    Map<String, Map<String, PropertyMapping>> byOpName = new LinkedHashMap<>();
    byOpName.put(tensorflowName(), byProperty);
    byOpName.put(onnxName(), byProperty);
    return byOpName;
}
Use of org.nd4j.imports.descriptors.properties.PropertyMapping in project nd4j by deeplearning4j.
The class Conv2D, method mappingsForFunction.
/**
 * Describes how this op's convolution properties are populated when it is imported.
 *
 * <p>Several properties share one mapping object (for example {@code sx} and {@code sy}
 * both point at the strides mapping). The same nested map instance is registered under
 * {@code onnxName()} and {@code tensorflowName()}; a {@code NoOpNameFoundException} raised
 * while registering either name is ignored, so the result may contain zero, one or two
 * entries.
 *
 * @return the property mappings for this op
 */
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
    PropertyMapping strides = PropertyMapping.builder()
            .tfAttrName("strides")
            .onnxAttrName("strides")
            .propertyNames(new String[] { "sx", "sy" })
            .build();
    // Kernel height/width are read from shape positions 0 and 1 of TF input 1.
    PropertyMapping kernelHeight = PropertyMapping.builder()
            .propertyNames(new String[] { "kh" })
            .tfInputPosition(1)
            .shapePosition(0)
            .onnxAttrName("kernel_shape")
            .build();
    PropertyMapping kernelWidth = PropertyMapping.builder()
            .propertyNames(new String[] { "kw" })
            .tfInputPosition(1)
            .shapePosition(1)
            .onnxAttrName("kernel_shape")
            .build();
    PropertyMapping dilations = PropertyMapping.builder()
            .onnxAttrName("dilations")
            .propertyNames(new String[] { "dw", "dh" })
            .tfAttrName("rates")
            .build();
    PropertyMapping format = PropertyMapping.builder()
            .onnxAttrName("data_format")
            .tfAttrName("data_format")
            .propertyNames(new String[] { "dataFormat" })
            .build();
    PropertyMapping isNhwc = PropertyMapping.builder()
            .onnxAttrName("data_format")
            .tfAttrName("data_format")
            .propertyNames(new String[] { "isNHWC" })
            .build();
    PropertyMapping autoPad = PropertyMapping.builder()
            .onnxAttrName("auto_pad")
            .propertyNames(new String[] { "isSameMode" })
            .tfAttrName("padding")
            .build();
    // NOTE(review): the ONNX Conv operator spec names this attribute "pads" -- confirm that
    // "padding" is what the importer expects here.
    PropertyMapping padding = PropertyMapping.builder()
            .onnxAttrName("padding")
            .propertyNames(new String[] { "ph", "pw" })
            .build();

    Map<String, PropertyMapping> byProperty = new HashMap<>();
    byProperty.put("sx", strides);
    byProperty.put("sy", strides);
    byProperty.put("kh", kernelHeight);
    byProperty.put("kw", kernelWidth);
    byProperty.put("dw", dilations);
    byProperty.put("dh", dilations);
    byProperty.put("isSameMode", autoPad);
    byProperty.put("ph", padding);
    byProperty.put("pw", padding);
    byProperty.put("dataFormat", format);
    byProperty.put("isNHWC", isNhwc);

    Map<String, Map<String, PropertyMapping>> byOpName = new HashMap<>();
    try {
        byOpName.put(onnxName(), byProperty);
    } catch (NoOpNameFoundException ignored) {
        // Deliberately skipped: presumably the op has no ONNX name, so nothing is registered.
    }
    try {
        byOpName.put(tensorflowName(), byProperty);
    } catch (NoOpNameFoundException ignored) {
        // Deliberately skipped: presumably the op has no TensorFlow name, so nothing is registered.
    }
    return byOpName;
}
Aggregations