Usage of `org.nd4j.imports.descriptors.properties.PropertyMapping` in the nd4j project by deeplearning4j.
Example: the `mappingsForFunction` method of class `Unstack`.
/**
 * Builds the import property mappings for this op.
 *
 * <p>One mapping is registered, for the {@code axis} property. It names the ONNX
 * attribute {@code axis} and sets the TensorFlow input position to {@code -1}.
 *
 * <p>NOTE(review): only {@code tensorflowName()} is registered as a key even though an
 * ONNX attribute name is set, and TF position {@code -1} is used instead of a TF attribute
 * name. Confirm both are intentional for this op.
 *
 * @return a map from {@code tensorflowName()} to the property mappings keyed by property name
 */
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
    PropertyMapping axis = PropertyMapping.builder()
            .propertyNames(new String[] {"axis"})
            .tfInputPosition(-1)
            .onnxAttrName("axis")
            .build();

    Map<String, PropertyMapping> byProperty = new HashMap<>();
    byProperty.put("axis", axis);

    Map<String, Map<String, PropertyMapping>> mappings = new HashMap<>();
    mappings.put(tensorflowName(), byProperty);
    return mappings;
}
Usage of `org.nd4j.imports.descriptors.properties.PropertyMapping` in the nd4j project by deeplearning4j.
Example: the `mappingsForFunction` method of class `DynamicPartition`.
/**
 * Builds the import property mappings for this op.
 *
 * <p>The TensorFlow attribute {@code num_partitions} is mapped onto the
 * {@code numPartitions} property.
 *
 * @return a map from {@code tensorflowName()} to the property mappings keyed by property name
 */
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
    Map<String, PropertyMapping> attributeMappings = new LinkedHashMap<>();
    attributeMappings.put("numPartitions", PropertyMapping.builder()
            .propertyNames(new String[] {"numPartitions"})
            .tfAttrName("num_partitions")
            .build());

    Map<String, Map<String, PropertyMapping>> result = new HashMap<>();
    result.put(tensorflowName(), attributeMappings);
    return result;
}
Usage of `org.nd4j.imports.descriptors.properties.PropertyMapping` in the nd4j project by deeplearning4j.
Example: the `mappingsForFunction` method of class `ReverseSequence`.
/**
 * Builds the import property mappings for this op.
 *
 * <p>The TensorFlow attribute {@code seq_dim} maps to the {@code seqDim} property, and
 * {@code batch_dim} maps to {@code batchDim}.
 *
 * @return a map from {@code tensorflowName()} to the property mappings keyed by property name
 */
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
    // LinkedHashMap keeps insertion order: seqDim first, then batchDim.
    Map<String, PropertyMapping> byProperty = new LinkedHashMap<>();
    byProperty.put("seqDim", PropertyMapping.builder()
            .tfAttrName("seq_dim")
            .propertyNames(new String[] {"seqDim"})
            .build());
    byProperty.put("batchDim", PropertyMapping.builder()
            .tfAttrName("batch_dim")
            .propertyNames(new String[] {"batchDim"})
            .build());

    Map<String, Map<String, PropertyMapping>> mappings = new HashMap<>();
    mappings.put(tensorflowName(), byProperty);
    return mappings;
}
Usage of `org.nd4j.imports.descriptors.properties.PropertyMapping` in the nd4j project by deeplearning4j.
Example: the `mappingsForFunction` method of class `Cast`.
/**
 * Builds the import property mappings for this op.
 *
 * <p>The TensorFlow attribute {@code DstT} is mapped onto the {@code typeDst} property.
 *
 * @return a map from {@code tensorflowName()} to the property mappings keyed by property name
 */
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
    PropertyMapping destinationType = PropertyMapping.builder()
            .propertyNames(new String[] {"typeDst"})
            .tfAttrName("DstT")
            .build();

    // Register the mapping under every property name it declares.
    Map<String, PropertyMapping> byProperty = new HashMap<>();
    for (String propertyName : destinationType.getPropertyNames()) {
        byProperty.put(propertyName, destinationType);
    }

    Map<String, Map<String, PropertyMapping>> mappings = new HashMap<>();
    mappings.put(tensorflowName(), byProperty);
    return mappings;
}
Usage of `org.nd4j.imports.descriptors.properties.PropertyMapping` in the nd4j project by deeplearning4j.
Example: the `mappingsForFunction` method of class `Conv3D`.
/**
 * Builds the import property mappings for this op.
 *
 * <p>Each mapping is registered under every property name it declares. The resulting
 * map is shared, as the same instance, by the {@code onnxName()} and
 * {@code tensorflowName()} entries.
 *
 * <p>NOTE(review): the declared property-name order is inconsistent across mappings.
 * Kernel, stride and padding use T/W/H, while dilation and output padding use T/H/W.
 * Confirm whether the consumer of {@code propertyNames} depends on this order.
 *
 * @return a map from the ONNX and TensorFlow op names to the property mappings keyed by
 *     property name
 */
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
    PropertyMapping[] allMappings = {
        PropertyMapping.builder()
                .onnxAttrName("kernel_shape")
                .tfInputPosition(1)
                .propertyNames(new String[] {"kT", "kW", "kH"})
                .build(),
        PropertyMapping.builder()
                .onnxAttrName("strides")
                .tfAttrName("strides")
                .propertyNames(new String[] {"dT", "dW", "dH"})
                .build(),
        PropertyMapping.builder()
                .onnxAttrName("dilations")
                .tfAttrName("rates")
                .propertyNames(new String[] {"dilationT", "dilationH", "dilationW"})
                .build(),
        PropertyMapping.builder()
                .onnxAttrName("auto_pad")
                .tfAttrName("padding")
                .propertyNames(new String[] {"isValidMode"})
                .build(),
        PropertyMapping.builder()
                .onnxAttrName("padding")
                .propertyNames(new String[] {"pT", "pW", "pH"})
                .build(),
        PropertyMapping.builder()
                .onnxAttrName("data_format")
                .tfAttrName("data_format")
                .propertyNames(new String[] {"dataFormat"})
                .build(),
        PropertyMapping.builder()
                .propertyNames(new String[] {"aT", "aH", "aW"})
                .build(),
        PropertyMapping.builder()
                .propertyNames(new String[] {"biasUsed"})
                .build()
    };

    Map<String, PropertyMapping> byProperty = new HashMap<>();
    for (PropertyMapping mapping : allMappings) {
        for (String propertyName : mapping.getPropertyNames()) {
            byProperty.put(propertyName, mapping);
        }
    }

    Map<String, Map<String, PropertyMapping>> mappings = new HashMap<>();
    // Both frameworks intentionally point at the same map instance.
    mappings.put(onnxName(), byProperty);
    mappings.put(tensorflowName(), byProperty);
    return mappings;
}
Aggregations