Use of org.nd4j.imports.descriptors.properties.AttributeAdapter in project nd4j by deeplearning4j.
In the class TFGraphMapper, method initFunctionFromProperties.
/**
 * Initializes a function's fields from the attributes and inputs of a TensorFlow node.
 * <p>
 * For every property mapping that {@code on} declares under {@code mappedTfName}:
 * <ul>
 *   <li>if the mapping names a TF attribute and the node carries it, the attribute is
 *       converted according to its value case and handed to the registered
 *       {@link AttributeAdapter}, or written directly with {@code on.setValueFor(...)}
 *       when no adapter is registered;</li>
 *   <li>otherwise, if the mapping names an input position, the array for that input is
 *       looked up and mapped onto the field; when no array can be found yet, the property
 *       is registered on the function's SameDiff instance to be resolved later.</li>
 * </ul>
 * Mappings whose field cannot be found are skipped rather than failing.
 *
 * @param mappedTfName the tensorflow name to pick (sometimes ops have multiple names)
 * @param on the function to map
 * @param attributesForNode the attributes for the node
 * @param node the node definition whose inputs are consulted for input-position mappings
 * @param graph the graph definition used to look up input nodes by name
 */
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, AttrValue> attributesForNode, NodeDef node, GraphDef graph) {
    val properties = on.mappingsForFunction();
    val tfProperties = properties.get(mappedTfName);
    val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
    val attributeAdapters = on.attributeAdaptersForFunction();

    // if there's no properties announced for this function - just return
    if (tfProperties == null) {
        return;
    }

    for (val entry : tfProperties.entrySet()) {
        val tfAttrName = entry.getValue().getTfAttrName();
        val currentField = fields.get(entry.getKey());

        AttributeAdapter adapter = null;
        if (attributeAdapters != null && !attributeAdapters.isEmpty()) {
            val mappers = attributeAdapters.get(mappedTfName);
            // Adapters may be registered only under a different TF name for this op;
            // in that case there is simply no adapter for this property.
            if (mappers != null) {
                adapter = mappers.get(entry.getKey());
            }
        }

        if (tfAttrName != null) {
            if (currentField == null) {
                continue;
            }

            if (!attributesForNode.containsKey(tfAttrName)) {
                continue;
            }

            val attr = attributesForNode.get(tfAttrName);
            switch (attr.getValueCase()) {
                case B:
                    // booleans are only mapped through an adapter
                    if (adapter != null) {
                        adapter.mapAttributeFor(attr.getB(), currentField, on);
                    }
                    break;
                case S:
                    val setString = attr.getS().toStringUtf8();
                    if (adapter != null) {
                        adapter.mapAttributeFor(setString, currentField, on);
                    } else {
                        on.setValueFor(currentField, setString);
                    }
                    break;
                case I:
                    // the 64-bit attribute value is narrowed to int
                    val setInt = (int) attr.getI();
                    if (adapter != null) {
                        adapter.mapAttributeFor(setInt, currentField, on);
                    } else {
                        on.setValueFor(currentField, setInt);
                    }
                    break;
                case SHAPE:
                    val shape = attr.getShape().getDimList();
                    int[] dimsToSet = new int[shape.size()];
                    for (int i = 0; i < dimsToSet.length; i++) {
                        dimsToSet[i] = (int) shape.get(i).getSize();
                    }
                    if (adapter != null) {
                        adapter.mapAttributeFor(dimsToSet, currentField, on);
                    } else {
                        on.setValueFor(currentField, dimsToSet);
                    }
                    break;
                case LIST:
                    val setList = attr.getList();
                    // Only the first non-empty sub-list is considered; boolean, function
                    // and tensor lists are not mapped.
                    if (!setList.getIList().isEmpty()) {
                        val intList = Ints.toArray(setList.getIList());
                        if (adapter != null) {
                            adapter.mapAttributeFor(intList, currentField, on);
                        } else {
                            on.setValueFor(currentField, intList);
                        }
                    } else if (setList.getBList().isEmpty() && !setList.getFList().isEmpty()) {
                        val floats = Floats.toArray(setList.getFList());
                        if (adapter != null) {
                            adapter.mapAttributeFor(floats, currentField, on);
                        } else {
                            on.setValueFor(currentField, floats);
                        }
                    }
                    break;
                case TENSOR:
                    val tensorToGet = TFGraphMapper.getInstance().mapTensorProto(attr.getTensor());
                    if (adapter != null) {
                        adapter.mapAttributeFor(tensorToGet, currentField, on);
                    } else {
                        on.setValueFor(currentField, tensorToGet);
                    }
                    break;
                case TYPE:
                    // data types are only mapped through an adapter
                    if (adapter != null) {
                        adapter.mapAttributeFor(attr.getType(), currentField, on);
                    }
                    break;
                case F:
                case FUNC:
                case VALUE_NOT_SET:
                case PLACEHOLDER:
                default:
                    // not mapped
                    break;
            }
        } else if (entry.getValue().getTfInputPosition() != null) {
            int position = entry.getValue().getTfInputPosition();
            // negative positions count back from the end of the node's input list
            if (position < 0) {
                position += node.getInputCount();
            }

            val inputFromNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, node.getInput(position));
            INDArray tensor = inputFromNode != null ? TFGraphMapper.getInstance().getNDArrayFromTensor("value", inputFromNode, graph) : null;
            if (tensor == null) {
                // fall back to an array already registered for the input's variable name
                tensor = on.getSameDiff().getArrForVarName(getNodeName(node.getInput(position)));
            }

            if (tensor == null) {
                on.getSameDiff().addPropertyToResolve(on, entry.getKey());
            } else if (adapter != null) {
                // use adapter instead of direct mapping just like above
                adapter.mapAttributeFor(tensor, currentField, on);
            } else if (currentField != null) {
                // Without an adapter the target type is read from the field itself, so a
                // mapping with no matching field is skipped instead of dereferencing null.
                val fieldType = currentField.getType();
                if (fieldType.equals(int[].class)) {
                    on.setValueFor(currentField, tensor.data().asInt());
                } else if (fieldType.equals(double[].class)) {
                    on.setValueFor(currentField, tensor.data().asDouble());
                } else if (fieldType.equals(float[].class)) {
                    on.setValueFor(currentField, tensor.data().asFloat());
                } else if (fieldType.equals(INDArray.class)) {
                    on.setValueFor(currentField, tensor);
                } else if (fieldType.equals(int.class)) {
                    on.setValueFor(currentField, tensor.getInt(0));
                } else if (fieldType.equals(double.class)) {
                    on.setValueFor(currentField, tensor.getDouble(0));
                } else if (fieldType.equals(float.class)) {
                    on.setValueFor(currentField, tensor.getFloat(0));
                }
            }
        }
    }
}
Aggregations