Example 46 with SameDiff

Use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.

The class LossFunctions, method mse.

/**
 * Mean squared error: L = mean( (predicted - label)^2)
 *
 * @param outputName  Name of the output SDVariable
 * @param predictions Predictions variable
 * @param label       Label variable
 * @param weights     Weights array. May be null, or any shape broadcastable with the predictions/label arrays.
 *                    Note that this is also used for masking (weight of 0 = 'masked out')
 * @param reduction   Type of reduction to perform for the loss function
 * @param dimensions  Dimension(s) to apply the loss function on
 * @return LossInfo - bean with variables etc for the loss function
 */
public static LossInfo mse(String outputName, SDVariable predictions, SDVariable label, SDVariable weights, Reduction reduction, int... dimensions) {
    LossInfo.Builder b = validate("mse", predictions, label, reduction);
    SameDiff sd = predictions.getSameDiff();
    if (weights == null) {
        weights = sd.one("mse_loss_weights", SCALAR);
    }
    SDVariable diff = predictions.sub(label);
    String name = (reduction == Reduction.NONE ? outputName : null);
    SDVariable preReduceLoss = sd.square(diff).mul(name, weights);
    return doReduce(sd, outputName, true, b, reduction, preReduceLoss, label, weights, dimensions);
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), SameDiff (org.nd4j.autodiff.samediff.SameDiff)
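
A minimal usage sketch, not part of the original source, showing how mse might be wired into a small graph. The shapes, the Reduction.MEAN_BY_COUNT constant, and the LossInfo.getLoss() accessor are assumptions from the period API; only Reduction.NONE is confirmed by the snippet above.

SameDiff sd = SameDiff.create();
// variables with illustrative shapes; in practice these come from the model
SDVariable predictions = sd.var("predictions", new int[] { 3, 4 });
SDVariable labels = sd.var("labels", new int[] { 3, 4 });
// passing null weights: a scalar weight of 1.0 is created internally, as shown above
LossInfo lossInfo = LossFunctions.mse("mseLoss", predictions, labels, null, Reduction.MEAN_BY_COUNT, 1);
SDVariable loss = lossInfo.getLoss(); // assumed accessor on the LossInfo bean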

Example 47 with SameDiff

Use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.

The class LossFunctions, method l2.

/**
 * L2 loss function: sum of squared errors, L = sum_i (predicted_i - actual_i)^2
 *
 * @param outputName  Name of the output SDVariable
 * @param predictions Predictions variable
 * @param label       Label variable
 * @param weights     Weights array. May be null, or any shape broadcastable with the predictions/label arrays.
 *                    Note that this is also used for masking (weight of 0 = 'masked out')
 * @param reduction   Type of reduction to perform for the loss function
 * @param dimensions  Dimension(s) to apply the loss function on
 * @return LossInfo - bean with variables etc for the loss function
 */
public static LossInfo l2(String outputName, SDVariable predictions, SDVariable label, SDVariable weights, Reduction reduction, int... dimensions) {
    LossInfo.Builder b = validate("l2", predictions, label, reduction);
    SameDiff sd = predictions.getSameDiff();
    if (weights == null) {
        weights = sd.one("l2_loss_weights", SCALAR);
    }
    SDVariable diff = predictions.sub(label);
    String name = (reduction == Reduction.NONE ? outputName : null);
    SDVariable preReduceLoss = sd.square(diff).mul(name, weights);
    return doReduce(sd, outputName, false, b, reduction, preReduceLoss, label, weights, dimensions);
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), SameDiff (org.nd4j.autodiff.samediff.SameDiff)
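
A plain-Java sanity check, not from the original source: the only difference between l2 and mse is the boolean third argument to doReduce (true for mse, false for l2), which evidently selects mean versus sum. The arithmetic below illustrates the relationship.

double[] pred = { 1.0, 2.0, 3.0 };
double[] lab = { 1.5, 2.0, 2.0 };
double sumSq = 0.0;
for (int i = 0; i < pred.length; i++) {
    double d = pred[i] - lab[i];
    sumSq += d * d; // 0.25 + 0.0 + 1.0
}
double l2 = sumSq;                // 1.25: sum of squared errors
double mse = sumSq / pred.length; // ~0.4167: mean of squared errors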

Example 48 with SameDiff

Use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.

The class LossFunctions, method mcxent.

/**
 * Multi-class cross entropy loss function:<br>
 * L = sum_i actual_i * log( predicted_i )
 *
 * @param outputName  Name of the output SDVariable
 * @param predictions Predictions variable
 * @param label       Label variable
 * @param weights     Weights array. May be null, or any shape broadcastable with the predictions/label arrays.
 *                    Note that this is also used for masking (weight of 0 = 'masked out')
 * @param reduction   Type of reduction to perform for the loss function
 * @param dimensions  Dimension(s) to apply the loss function on
 * @return LossInfo - bean with variables etc for the loss function
 */
public static LossInfo mcxent(String outputName, SDVariable predictions, SDVariable label, SDVariable weights, Reduction reduction, int... dimensions) {
    LossInfo.Builder b = validate("mcxent", predictions, label, reduction);
    SameDiff sd = predictions.getSameDiff();
    if (weights == null) {
        weights = sd.one("mcxent_loss_weights", SCALAR);
    }
    String name = (reduction == Reduction.NONE ? outputName : null);
    SDVariable weightedLogProd = sd.log(predictions).mul(label).mul(name, weights);
    return doReduce(sd, outputName, false, b, reduction, weightedLogProd, label, weights, dimensions);
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), SameDiff (org.nd4j.autodiff.samediff.SameDiff)
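
A hedged usage sketch, not from the original source: mcxent applied to softmax probabilities with one-hot labels, the typical classification setup. The sd.softmax call and the Reduction.SUM constant are assumptions from the period API. Note that the method does not negate the weighted log-product, so it returns sum_i label_i * log(pred_i) exactly as documented; the conventional cross-entropy loss is the negation of that quantity.

SameDiff sd = SameDiff.create();
SDVariable logits = sd.var("logits", new int[] { 3, 10 });
SDVariable labels = sd.var("labels", new int[] { 3, 10 }); // one-hot rows
SDVariable probs = sd.softmax(logits); // assumed API: softmax activation
LossInfo lossInfo = LossFunctions.mcxent("ceLoss", probs, labels, null, Reduction.SUM, 1);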

Example 49 with SameDiff

Use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.

The class LossFunctions, method l1.

/**
 * L1 loss function: sum of absolute errors, L = sum_i abs(predicted_i - actual_i)
 *
 * @param outputName  Name of the output SDVariable
 * @param predictions Predictions variable
 * @param label       Label variable
 * @param weights     Weights array. May be null, or any shape broadcastable with the predictions/label arrays.
 *                    Note that this is also used for masking (weight of 0 = 'masked out')
 * @param reduction   Type of reduction to perform for the loss function
 * @param dimensions  Dimension(s) to apply the loss function on
 * @return LossInfo - bean with variables etc for the loss function
 */
public static LossInfo l1(String outputName, SDVariable predictions, SDVariable label, SDVariable weights, Reduction reduction, int... dimensions) {
    LossInfo.Builder b = validate("l1", predictions, label, reduction);
    SameDiff sd = predictions.getSameDiff();
    if (weights == null) {
        weights = sd.one("l1_loss_weights", SCALAR);
    }
    String name = (reduction == Reduction.NONE ? outputName : null);
    SDVariable preReduceLoss = sd.abs(predictions.sub(label)).mul(name, weights);
    return doReduce(sd, outputName, false, b, reduction, preReduceLoss, label, weights, dimensions);
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), SameDiff (org.nd4j.autodiff.samediff.SameDiff)
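
A hypothetical sketch, not from the original source, showing the weights argument used as a mask, per the mse javadoc above (weight of 0 = 'masked out'). The broadcastable [3, 1] mask shape and the Reduction.SUM constant are assumptions.

SameDiff sd = SameDiff.create();
SDVariable predictions = sd.var("predictions", new int[] { 3, 4 });
SDVariable labels = sd.var("labels", new int[] { 3, 4 });
// one weight per row, broadcast across columns; a zero row contributes nothing to the loss
SDVariable mask = sd.var("mask", new int[] { 3, 1 });
LossInfo lossInfo = LossFunctions.l1("l1Loss", predictions, labels, mask, Reduction.SUM, 1);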

Example 50 with SameDiff

Use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.

The class BaseGraphMapper, method importGraph.

/**
 * This method converts the given TensorFlow graph definition into a SameDiff instance.
 *
 * @param tfGraph the TensorFlow graph to import
 * @return the imported graph as a SameDiff instance
 */
@Override
public SameDiff importGraph(GRAPH_TYPE tfGraph) {
    SameDiff diff = SameDiff.create();
    ImportState<GRAPH_TYPE, TENSOR_TYPE> importState = new ImportState<>();
    importState.setSameDiff(diff);
    importState.setGraph(tfGraph);
    val variablesForGraph = variablesForGraph(tfGraph);
    importState.setVariables(variablesForGraph);
    // for each variable
    for (Map.Entry<String, TENSOR_TYPE> entry : variablesForGraph.entrySet()) {
        if (dataTypeForTensor(entry.getValue()) == DataBuffer.Type.UNKNOWN) {
            val var = importState.getSameDiff().var(entry.getKey(), null, new ZeroInitScheme('c'));
            // mark as place holder for validating resolution later.
            if (isPlaceHolder(entry.getValue())) {
                importState.getSameDiff().addAsPlaceHolder(var.getVarName());
                if (var.getShape() != null)
                    importState.getSameDiff().setOriginalPlaceHolderShape(var.getVarName(), var.getShape());
            }
            continue;
        }
        val arr = getNDArrayFromTensor(entry.getKey(), entry.getValue(), tfGraph);
        if (arr != null) {
            val var = importState.getSameDiff().var(entry.getKey(), arr);
            // ensure the array is made available for later processing
            diff.associateArrayWithVariable(arr, var);
        } else if (getShapeFromTensor(entry.getValue()) == null) {
            val var = importState.getSameDiff().var(entry.getKey(), null, new ZeroInitScheme('c'));
            // no array or shape information is available; mark as a placeholder for later resolution
            if (isPlaceHolder(entry.getValue())) {
                val originalShape = getShapeFromTensor(entry.getValue());
                importState.getSameDiff().addAsPlaceHolder(var.getVarName());
                if (var.getShape() != null)
                    importState.getSameDiff().setOriginalPlaceHolderShape(var.getVarName(), originalShape);
            }
        } else {
            val originalShape = getShapeFromTensor(entry.getValue());
            val var = importState.getSameDiff().var(entry.getKey(), originalShape);
            // shape is known, but the tensor may still be a placeholder
            if (isPlaceHolder(entry.getValue())) {
                importState.getSameDiff().addAsPlaceHolder(var.getVarName());
                importState.getSameDiff().setOriginalPlaceHolderShape(var.getVarName(), originalShape);
            }
        }
    }
    // map each graph node to a SameDiff op, skipping ignored op types
    val tfNodesList = getNodeList(tfGraph);
    for (NODE_TYPE tfNode : tfNodesList) {
        if (!opsToIgnore().contains(getOpType(tfNode)) || isOpIgnoreException(tfNode))
            mapNodeType(tfNode, importState);
    }
    return diff;
}
Also used: lombok.val, ZeroInitScheme (org.nd4j.weightinit.impl.ZeroInitScheme), SameDiff (org.nd4j.autodiff.samediff.SameDiff), HashMap (java.util.HashMap), LinkedHashMap (java.util.LinkedHashMap), Map (java.util.Map)
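
For context, a hypothetical invocation, not from the original source: BaseGraphMapper is abstract, so imports go through a concrete subclass. The TFGraphMapper class name and its getInstance()/importGraph(File) entry points are assumptions based on the nd4j import API of this era.

// import a frozen TensorFlow graph; the file path is illustrative
File frozenGraph = new File("frozen_model.pb");
SameDiff imported = TFGraphMapper.getInstance().importGraph(frozenGraph);
// variables and placeholders from the TF graph are now registered on 'imported'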

Aggregations

SameDiff (org.nd4j.autodiff.samediff.SameDiff): 50 usages
Test (org.junit.Test): 42 usages
SDVariable (org.nd4j.autodiff.samediff.SDVariable): 41 usages
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 37 usages
ArrayList (java.util.ArrayList): 10 usages
DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp): 10 usages
Ignore (org.junit.Ignore): 7 usages
ClassPathResource (org.nd4j.linalg.io.ClassPathResource): 6 usages
lombok.val (lombok.val): 4 usages
LossFunctions (org.nd4j.autodiff.loss.LossFunctions): 4 usages
LossInfo (org.nd4j.autodiff.loss.LossInfo): 4 usages
BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution): 4 usages
DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction): 2 usages
Triple (org.nd4j.linalg.primitives.Triple): 2 usages
ZeroInitScheme (org.nd4j.weightinit.impl.ZeroInitScheme): 2 usages
DataOutputStream (java.io.DataOutputStream): 1 usage
FileOutputStream (java.io.FileOutputStream): 1 usage
ByteBuffer (java.nio.ByteBuffer): 1 usage
HashMap (java.util.HashMap): 1 usage
LinkedHashMap (java.util.LinkedHashMap): 1 usage