Search in sources:

Example 11 with AttributeAdapter

Use of org.nd4j.imports.descriptors.properties.AttributeAdapter in the project nd4j by deeplearning4j.

Method initFunctionFromProperties of the class TFGraphMapper.

/**
 * Initializes a function's properties from the attributes and inputs of a TensorFlow node.
 * <p>
 * For every property mapping registered on {@code on} for {@code mappedTfName}, the value is
 * read from one of two places:
 * <ul>
 *   <li>the node attribute named by the mapping's TF attribute name, or</li>
 *   <li>when that name is null, the node input at the mapping's TF input position. A negative
 *   position is counted from the end of the node's input list.</li>
 * </ul>
 * If an {@link AttributeAdapter} is registered for the property, the adapter performs the
 * assignment. Otherwise the value is written directly with {@code on.setValueFor}. An input
 * whose array cannot be found yet is registered with {@code addPropertyToResolve} on the
 * function's SameDiff instance.
 * <p>
 * Not every attribute kind is mapped:
 * <ul>
 *   <li>float, function, placeholder and unset attributes are ignored;</li>
 *   <li>boolean and type attributes are only mapped through an adapter;</li>
 *   <li>only int and float lists are mapped.</li>
 * </ul>
 *
 * @param mappedTfName      the TensorFlow op name to pick mappings for (some ops have multiple names)
 * @param on                the function whose properties are set
 * @param attributesForNode the attributes of {@code node}, keyed by TensorFlow attribute name
 * @param node              the node being imported; its inputs back input-position properties
 * @param graph             the graph containing {@code node}, used to look up input nodes
 */
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, AttrValue> attributesForNode, NodeDef node, GraphDef graph) {
    val properties = on.mappingsForFunction();
    val tfProperties = properties.get(mappedTfName);
    val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
    val attributeAdapters = on.attributeAdaptersForFunction();
    // if no properties are declared for this op name there is nothing to map
    if (tfProperties == null) {
        return;
    }
    for (val entry : tfProperties.entrySet()) {
        val tfAttrName = entry.getValue().getTfAttrName();
        val currentField = fields.get(entry.getKey());
        AttributeAdapter adapter = null;
        if (attributeAdapters != null && !attributeAdapters.isEmpty()) {
            // adapters may be registered for other op names only; a missing entry means "no adapter"
            val mappers = attributeAdapters.get(mappedTfName);
            if (mappers != null) {
                adapter = mappers.get(entry.getKey());
            }
        }
        if (tfAttrName != null) {
            if (currentField == null) {
                continue;
            }
            if (attributesForNode.containsKey(tfAttrName)) {
                val attr = attributesForNode.get(tfAttrName);
                switch (attr.getValueCase()) {
                    case B:
                        // booleans are only mapped through an adapter
                        if (adapter != null) {
                            adapter.mapAttributeFor(attr.getB(), currentField, on);
                        }
                        break;
                    case F:
                        // NOTE(review): scalar float attributes are not mapped at all
                        break;
                    case FUNC:
                        break;
                    case S:
                        val setString = attr.getS().toStringUtf8();
                        if (adapter != null) {
                            adapter.mapAttributeFor(setString, currentField, on);
                        } else {
                            on.setValueFor(currentField, setString);
                        }
                        break;
                    case I:
                        // narrowed from the protobuf long to int
                        val setInt = (int) attr.getI();
                        if (adapter != null) {
                            adapter.mapAttributeFor(setInt, currentField, on);
                        } else {
                            on.setValueFor(currentField, setInt);
                        }
                        break;
                    case SHAPE:
                        val shape = attr.getShape().getDimList();
                        int[] dimsToSet = new int[shape.size()];
                        for (int i = 0; i < dimsToSet.length; i++) {
                            dimsToSet[i] = (int) shape.get(i).getSize();
                        }
                        if (adapter != null) {
                            adapter.mapAttributeFor(dimsToSet, currentField, on);
                        } else {
                            on.setValueFor(currentField, dimsToSet);
                        }
                        break;
                    case VALUE_NOT_SET:
                        break;
                    case PLACEHOLDER:
                        break;
                    case LIST:
                        val setList = attr.getList();
                        // checked in order I, B, F; the first non-empty list wins
                        if (!setList.getIList().isEmpty()) {
                            val intList = Ints.toArray(setList.getIList());
                            if (adapter != null) {
                                adapter.mapAttributeFor(intList, currentField, on);
                            } else {
                                on.setValueFor(currentField, intList);
                            }
                        } else if (!setList.getBList().isEmpty()) {
                            // boolean lists are not mapped
                        } else if (!setList.getFList().isEmpty()) {
                            val floats = Floats.toArray(setList.getFList());
                            if (adapter != null) {
                                adapter.mapAttributeFor(floats, currentField, on);
                            } else {
                                on.setValueFor(currentField, floats);
                            }
                        }
                        // function and tensor lists are not mapped
                        break;
                    case TENSOR:
                        val tensorToGet = TFGraphMapper.getInstance().mapTensorProto(attr.getTensor());
                        if (adapter != null) {
                            adapter.mapAttributeFor(tensorToGet, currentField, on);
                        } else {
                            on.setValueFor(currentField, tensorToGet);
                        }
                        break;
                    case TYPE:
                        // data types are only mapped through an adapter
                        if (adapter != null) {
                            adapter.mapAttributeFor(attr.getType(), currentField, on);
                        }
                        break;
                }
            }
        } else if (entry.getValue().getTfInputPosition() != null) {
            int position = entry.getValue().getTfInputPosition();
            // negative positions are relative to the end of the node's input list
            if (position < 0) {
                position += node.getInputCount();
            }
            String inputName = node.getInput(position);
            val inputFromNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, inputName);
            INDArray tensor = inputFromNode != null ? TFGraphMapper.getInstance().getNDArrayFromTensor("value", inputFromNode, graph) : null;
            if (tensor == null) {
                // fall back to an array already registered in the function's SameDiff instance
                tensor = on.getSameDiff().getArrForVarName(getNodeName(inputName));
            }
            if (tensor == null) {
                // value not available yet: ask SameDiff to resolve this property later
                on.getSameDiff().addPropertyToResolve(on, entry.getKey());
            } else if (adapter != null) {
                // use the adapter instead of direct mapping, as in the attribute branch
                adapter.mapAttributeFor(tensor, currentField, on);
            } else if (currentField != null) {
                // direct mapping: convert the array to the declared field type
                if (currentField.getType().equals(int[].class)) {
                    on.setValueFor(currentField, tensor.data().asInt());
                } else if (currentField.getType().equals(double[].class)) {
                    on.setValueFor(currentField, tensor.data().asDouble());
                } else if (currentField.getType().equals(float[].class)) {
                    on.setValueFor(currentField, tensor.data().asFloat());
                } else if (currentField.getType().equals(INDArray.class)) {
                    on.setValueFor(currentField, tensor);
                } else if (currentField.getType().equals(int.class)) {
                    on.setValueFor(currentField, tensor.getInt(0));
                } else if (currentField.getType().equals(double.class)) {
                    on.setValueFor(currentField, tensor.getDouble(0));
                } else if (currentField.getType().equals(float.class)) {
                    on.setValueFor(currentField, tensor.getFloat(0));
                }
            }
        }
    }
}
Also used: lombok.val (lombok.val), INDArray (org.nd4j.linalg.api.ndarray.INDArray), AttributeAdapter (org.nd4j.imports.descriptors.properties.AttributeAdapter)

Aggregations

AttributeAdapter (org.nd4j.imports.descriptors.properties.AttributeAdapter): 11; lombok.val (lombok.val): 8; ConditionalFieldValueNDArrayShapeAdapter (org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueNDArrayShapeAdapter): 4; StringEqualsAdapter (org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter): 4; ConditionalFieldValueIntIndexArrayAdapter (org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter): 3; SizeThresholdIntArrayIntIndexAdpater (org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater): 3; BooleanAdapter (org.nd4j.imports.descriptors.properties.adapters.BooleanAdapter): 2; IntArrayIntIndexAdpater (org.nd4j.imports.descriptors.properties.adapters.IntArrayIntIndexAdpater): 2; HashMap (java.util.HashMap): 1; LinkedHashMap (java.util.LinkedHashMap): 1; Map (java.util.Map): 1; NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException): 1; DataTypeAdapter (org.nd4j.imports.descriptors.properties.adapters.DataTypeAdapter): 1; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1