use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.
the class LossFunctions method mse.
/**
* Mean squared error: L = mean( (predicted - label)^2)
*
* @param outputName Name of the output SDVariable
* @param predictions Predictions variable
* @param label Label variable
* @param weights Weights array. May be null, or any broadcastable shape (with predictions/label arrays).
* Note that this is also used for masking (weight of 0 = 'masked out')
* @param reduction Type of reduction to perform for the loss function
* @param dimensions Dimension(s) to apply the loss function on
* @return LossInfo - bean with variables etc for the loss function
*/
public static LossInfo mse(String outputName, SDVariable predictions, SDVariable label, SDVariable weights, Reduction reduction, int... dimensions) {
    LossInfo.Builder b = validate("mse", predictions, label, reduction);
    SameDiff sd = predictions.getSameDiff();
    if (weights == null) {
        weights = sd.one("mse_loss_weights", SCALAR);
    }
    SDVariable diff = predictions.sub(label);
    String name = (reduction == Reduction.NONE ? outputName : null);
    SDVariable preReduceLoss = sd.square(diff).mul(name, weights);
    return doReduce(sd, outputName, true, b, reduction, preReduceLoss, label, weights, dimensions);
}
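For context, a minimal usage sketch (not part of the original listing). It assumes the alpha-era SameDiff.var(String, int...) factory for declaring variables, and uses the Reduction.NONE mode that appears in the method above; names such as "mseLoss" are placeholders.

    SameDiff sd = SameDiff.create();
    SDVariable predictions = sd.var("predictions", 4, 3);
    SDVariable label = sd.var("label", 4, 3);
    // weights == null: the method substitutes a scalar weight of 1.0 internally
    LossInfo lossInfo = LossFunctions.mse("mseLoss", predictions, label, null, Reduction.NONE, 1);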
use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.
the class LossFunctions method l2.
/**
 * L2 loss function: i.e., sum of squared errors, L = sum_i (actual_i - predicted_i)^2
 *
 * @param outputName Name of the output SDVariable
 * @param predictions Predictions variable
 * @param label Label variable
 * @param weights Weights array. May be null, or any broadcastable shape (with predictions/label arrays).
 * Note that this is also used for masking (weight of 0 = 'masked out')
 * @param reduction Type of reduction to perform for the loss function
 * @param dimensions Dimension(s) to apply the loss function on
 * @return LossInfo - bean with variables etc for the loss function
 */
public static LossInfo l2(String outputName, SDVariable predictions, SDVariable label, SDVariable weights, Reduction reduction, int... dimensions) {
    LossInfo.Builder b = validate("l2", predictions, label, reduction);
    SameDiff sd = predictions.getSameDiff();
    if (weights == null) {
        weights = sd.one("l2_loss_weights", SCALAR);
    }
    SDVariable diff = predictions.sub(label);
    String name = (reduction == Reduction.NONE ? outputName : null);
    SDVariable preReduceLoss = sd.square(diff).mul(name, weights);
    return doReduce(sd, outputName, false, b, reduction, preReduceLoss, label, weights, dimensions);
}
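Because a weight of 0 masks an element out (see the weights note in the mse Javadoc above), per-element masking can be expressed by passing an explicit weights variable of the same shape. A hedged sketch under the same API assumptions as the mse example:

    SameDiff sd = SameDiff.create();
    SDVariable predictions = sd.var("predictions", 2, 3);
    SDVariable label = sd.var("label", 2, 3);
    // expected to hold 1s for included elements and 0s for masked-out elements at runtime
    SDVariable mask = sd.var("mask", 2, 3);
    LossInfo masked = LossFunctions.l2("l2Loss", predictions, label, mask, Reduction.NONE, 1);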
use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.
the class LossFunctions method mcxent.
/**
 * Multi-Class Cross Entropy loss function:<br>
 * L = sum_i actual_i * log( predicted_i )
 *
 * @param outputName Name of the output SDVariable
 * @param predictions Predictions variable
 * @param label Label variable
 * @param weights Weights array. May be null, or any broadcastable shape (with predictions/label arrays).
 * Note that this is also used for masking (weight of 0 = 'masked out')
 * @param reduction Type of reduction to perform for the loss function
 * @param dimensions Dimension(s) to apply the loss function on
 * @return LossInfo - bean with variables etc for the loss function
 */
public static LossInfo mcxent(String outputName, SDVariable predictions, SDVariable label, SDVariable weights, Reduction reduction, int... dimensions) {
    LossInfo.Builder b = validate("mcxent", predictions, label, reduction);
    SameDiff sd = predictions.getSameDiff();
    if (weights == null) {
        weights = sd.one("mcxent_loss_weights", SCALAR);
    }
    String name = (reduction == Reduction.NONE ? outputName : null);
    SDVariable weightedLogProd = sd.log(predictions).mul(label).mul(name, weights);
    return doReduce(sd, outputName, false, b, reduction, weightedLogProd, label, weights, dimensions);
}
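Note that mcxent applies log() to the predictions directly, so they are expected to already be probabilities rather than raw logits. A hedged sketch, assuming a softmax op is available on SameDiff in this API version:

    SameDiff sd = SameDiff.create();
    SDVariable logits = sd.var("logits", 4, 10);
    SDVariable labels = sd.var("labels", 4, 10); // one-hot rows
    // normalize to probabilities before the loss, since mcxent takes log(predictions)
    SDVariable probs = sd.softmax(logits);
    LossInfo ce = LossFunctions.mcxent("ceLoss", probs, labels, null, Reduction.NONE, 1);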
use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.
the class LossFunctions method l1.
/**
 * L1 loss - sum of absolute errors. L = sum_i abs(predicted_i - actual_i)
 *
 * @param outputName Name of the output SDVariable
 * @param predictions Predictions variable
 * @param label Label variable
 * @param weights Weights array. May be null, or any broadcastable shape (with predictions/label arrays).
 * Note that this is also used for masking (weight of 0 = 'masked out')
 * @param reduction Type of reduction to perform for the loss function
 * @param dimensions Dimension(s) to apply the loss function on
 * @return LossInfo - bean with variables etc for the loss function
 */
public static LossInfo l1(String outputName, SDVariable predictions, SDVariable label, SDVariable weights, Reduction reduction, int... dimensions) {
    LossInfo.Builder b = validate("l1", predictions, label, reduction);
    SameDiff sd = predictions.getSameDiff();
    if (weights == null) {
        weights = sd.one("l1_loss_weights", SCALAR);
    }
    String name = (reduction == Reduction.NONE ? outputName : null);
    SDVariable preReduceLoss = sd.abs(predictions.sub(label)).mul(name, weights);
    return doReduce(sd, outputName, false, b, reduction, preReduceLoss, label, weights, dimensions);
}
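Note that l1, l2, and mse above share the same skeleton: they differ only in the element-wise transform (abs for l1 versus square for l2/mse) and in the boolean passed to doReduce (true for mse, false for l1 and l2), which matches the mean in mse's formula against the sums in the l1 and l2 formulas.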
use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.
the class BaseGraphMapper method importGraph.
/**
 * This method converts the given TF graph definition into a SameDiff instance.
 *
 * @param tfGraph the graph to import
 * @return the imported graph as a SameDiff instance
 */
@Override
public SameDiff importGraph(GRAPH_TYPE tfGraph) {
    SameDiff diff = SameDiff.create();
    ImportState<GRAPH_TYPE, TENSOR_TYPE> importState = new ImportState<>();
    importState.setSameDiff(diff);
    importState.setGraph(tfGraph);
    val variablesForGraph = variablesForGraph(tfGraph);
    importState.setVariables(variablesForGraph);

    // for each variable in the graph
    for (Map.Entry<String, TENSOR_TYPE> entry : variablesForGraph.entrySet()) {
        if (dataTypeForTensor(entry.getValue()) == DataBuffer.Type.UNKNOWN) {
            val var = importState.getSameDiff().var(entry.getKey(), null, new ZeroInitScheme('c'));
            // mark as place holder for validating resolution later
            if (isPlaceHolder(entry.getValue())) {
                importState.getSameDiff().addAsPlaceHolder(var.getVarName());
                if (var.getShape() != null)
                    importState.getSameDiff().setOriginalPlaceHolderShape(var.getVarName(), var.getShape());
            }
            continue;
        }

        val arr = getNDArrayFromTensor(entry.getKey(), entry.getValue(), tfGraph);
        if (arr != null) {
            val var = importState.getSameDiff().var(entry.getKey(), arr);
            // ensure the array is made available for later processing
            diff.associateArrayWithVariable(arr, var);
        } else if (getShapeFromTensor(entry.getValue()) == null) {
            // no array and no shape: create an empty variable, and mark it as a
            // place holder if the tensor is one
            val var = importState.getSameDiff().var(entry.getKey(), null, new ZeroInitScheme('c'));
            if (isPlaceHolder(entry.getValue())) {
                val originalShape = getShapeFromTensor(entry.getValue());
                importState.getSameDiff().addAsPlaceHolder(var.getVarName());
                if (var.getShape() != null)
                    importState.getSameDiff().setOriginalPlaceHolderShape(var.getVarName(), originalShape);
            }
        } else {
            // shape is known: create the variable from the shape, and record the
            // original shape if the tensor is a place holder
            val originalShape = getShapeFromTensor(entry.getValue());
            val var = importState.getSameDiff().var(entry.getKey(), originalShape);
            if (isPlaceHolder(entry.getValue())) {
                importState.getSameDiff().addAsPlaceHolder(var.getVarName());
                importState.getSameDiff().setOriginalPlaceHolderShape(var.getVarName(), originalShape);
            }
        }
    }

    // map each node in the graph to a SameDiff op, skipping ignored op types
    val tfNodesList = getNodeList(tfGraph);
    for (NODE_TYPE tfNode : tfNodesList) {
        if (!opsToIgnore().contains(getOpType(tfNode)) || isOpIgnoreException(tfNode))
            mapNodeType(tfNode, importState);
    }
    return diff;
}
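For context, a hedged usage sketch for the concrete TensorFlow mapper (not part of the original listing). It assumes the alpha-era import API, where TFGraphMapper.getInstance() returns the BaseGraphMapper subclass for TensorFlow and an importGraph(File) overload parses a frozen GraphDef; the file name is a placeholder.

    import java.io.File;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.imports.graphmapper.tf.TFGraphMapper;

    SameDiff imported = TFGraphMapper.getInstance().importGraph(new File("frozen_model.pb"));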