Search in sources :

Example 21 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

From the class LossFunctions, method nonZeroCount.

/**
 * Counts how many weight entries are non-zero once the weights have been
 * broadcast to the shape of the labels.
 *
 * @param weights weight variable whose non-zero entries are counted
 * @param labels  labels variable supplying the broadcast target shape
 * @return scalar SDVariable holding the non-zero count after broadcasting
 */
private static SDVariable nonZeroCount(SDVariable weights, SDVariable labels) {
    final SameDiff sd = weights.getSameDiff();
    // 1.0 where the weight is non-zero, 0.0 elsewhere
    final SDVariable nonZeroMask = sd.neq(weights, 0.0);
    // Adding to zerosLike(labels) broadcasts the mask to the labels' shape
    return sd.sum(sd.zerosLike(labels).add(nonZeroMask));
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable) SameDiff(org.nd4j.autodiff.samediff.SameDiff)

Example 22 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

From the class LossFunctions, method doReduce.

/**
 * Perform the final reduction on a per-element loss, according to the requested
 * {@link Reduction} mode.
 *
 * @param sd            SameDiff instance the loss graph belongs to
 * @param outputName    name to assign to the final (output) loss variable
 * @param isMean        true if the per-example reduction is a mean (e.g. MSE),
 *                      false if it is a sum (e.g. L1)
 * @param b             builder accumulating the LossInfo; the reduced loss is set on it
 * @param reduction     reduction mode to apply (NONE, SPECIFIED_DIMS, SUM, MEAN_BY_WEIGHT, MEAN_BY_COUNT)
 * @param preReduceLoss per-element loss values, same shape as predictions/labels
 * @param label         labels variable (used for broadcasting in MEAN_BY_COUNT)
 * @param weights       weights variable (used in MEAN_BY_WEIGHT and MEAN_BY_COUNT)
 * @param dimensions    dimensions to reduce along for the per-example reduction
 * @return the built LossInfo containing the reduced loss variable
 */
private static LossInfo doReduce(SameDiff sd, String outputName, boolean isMean, LossInfo.Builder b, Reduction reduction, SDVariable preReduceLoss, SDVariable label, SDVariable weights, int[] dimensions) {
    switch(reduction) {
        case NONE:
            // Return same shape as predictions/labels
            b.loss(preReduceLoss);
            break;
        case SPECIFIED_DIMS:
            // Reduce along specified dimensions
            if (isMean) {
                // Example: MSE + mean along examples
                b.loss(sd.mean(outputName, preReduceLoss, dimensions));
            } else {
                // Example: L1 loss (sum) + mean along examples
                b.loss(sd.sum(outputName, preReduceLoss, dimensions));
            }
            // BUG FIX: this break was missing, so SPECIFIED_DIMS fell through
            // into the SUM case, overwriting the loss just set above (and
            // attempting to re-use outputName for a second variable).
            break;
        case SUM:
            if (isMean) {
                // Example: MSE (mean) + sum along examples
                SDVariable m = sd.mean(preReduceLoss, dimensions);
                b.loss(sd.sum(outputName, m));
            } else {
                // Example: L1 loss (sum) + sum along examples -> sum along all dimensions
                b.loss(sd.sum(outputName, preReduceLoss));
            }
            break;
        case MEAN_BY_WEIGHT:
            SDVariable weightSum = sd.sum(weights);
            if (isMean) {
                // Example: MSE (mean) + mean by weights over examples
                // reduce along dims + reduce along remaining dims == reduce along *all* dims
                SDVariable m2 = sd.mean(preReduceLoss);
                b.loss(m2.div(outputName, weightSum));
            } else {
                // Example: L1 (sum) + mean by weights over examples
                SDVariable sum = sd.sum(preReduceLoss, dimensions);
                b.loss(sum.div(outputName, weightSum));
            }
            break;
        case MEAN_BY_COUNT:
            SDVariable nonZeroWeights = nonZeroCount(weights, label);
            SDVariable r;
            if (isMean) {
                // Example: MSE (mean) + mean by count over examples
                r = sd.sum(preReduceLoss);
            } else {
                // Example: L1 (sum) + mean by count over examples
                SDVariable sum = sd.sum(preReduceLoss, dimensions);
                r = sd.mean(sum);
            }
            // sum (or mean of per-example sums) divided by the non-zero weight count
            b.loss(r.div(outputName, nonZeroWeights));
            break;
        default:
            throw new RuntimeException("Unknown reduction: " + reduction);
    }
    return b.build();
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable)

Example 23 with SDVariable

use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

the class Diag method doDiff.

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // Gradient of diag is the diagonal part of the incoming gradient.
    final SDVariable incomingGrad = i_v.get(0);
    return Arrays.asList(sameDiff.diagPart(incomingGrad));
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable)

Example 24 with SDVariable

use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

the class DynamicStitch method doDiff.

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // DynamicPartition and DynamicStitch are mutually inverse, so the gradient
    // of a stitch is a partition of the incoming gradient.
    final SDVariable incomingGrad = i_v.get(0);
    // Build per-partition constant data: partition i contributes the value i
    // everywhere, shaped like indices[i].
    final SDVariable[] partitionData = new SDVariable[indices.length];
    for (int i = 0; i < indices.length; i++) {
        partitionData[i] = sameDiff.onesLike(indices[i]).mul(i);
    }
    final SDVariable partitions = sameDiff.dynamicStitch(partitionData, indices);
    final SDVariable[] partitioned = sameDiff.dynamicPartition(incomingGrad, partitions, numPartitions);
    return new ArrayList<>(Arrays.asList(partitioned));
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable)

Example 25 with SDVariable

use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

the class GradCheckTransforms method testDynamicStitch.

@Test
public void testDynamicStitch() {
    SameDiff sd = SameDiff.create();
    INDArray ia = Nd4j.create(new float[] { 5, 1, 3 }, new int[] { 1, 3 });
    INDArray ib = Nd4j.create(new float[] { 7, 2, 4 }, new int[] { 1, 3 });
    INDArray indexA = Nd4j.create(new float[] { 0, 1, 4 }, new int[] { 1, 3 });
    INDArray indexB = Nd4j.create(new float[] { 2, 3, 5 }, new int[] { 1, 3 });
    // Expected output computed by executing the raw custom op directly
    INDArray expOut = Nd4j.create(new int[] { 1, 6 });
    DynamicCustomOp dynamicStitch = DynamicCustomOp.builder("dynamic_stitch").addInputs(indexA, indexB, ia, ib).addOutputs(expOut).build();
    Nd4j.getExecutioner().exec(dynamicStitch);
    SDVariable in1 = sd.var("in1", new int[] { 1, 3 });
    SDVariable in2 = sd.var("in2", new int[] { 1, 3 });
    SDVariable index1 = sd.var("index1", new int[] { 1, 3 });
    SDVariable index2 = sd.var("index2", new int[] { 1, 3 });
    sd.associateArrayWithVariable(ia, in1);
    sd.associateArrayWithVariable(ib, in2);
    sd.associateArrayWithVariable(indexA, index1);
    sd.associateArrayWithVariable(indexB, index2);
    SDVariable t = sd.dynamicStitch(new SDVariable[] { index1, index2 }, new SDVariable[] { in1, in2 });
    // Scalar loss kept so the graph has a reducible output (presumably needed
    // by the gradient check machinery — TODO confirm against sibling tests).
    SDVariable loss = sd.mean("loss", t);
    sd.exec();
    INDArray out = t.getArr();
    if (!expOut.equals(out)) {
        log.error("forward failed");
        // BUG FIX: previously the mismatch was only logged, so this @Test
        // could never fail. Throw so the test runner reports the failure.
        throw new AssertionError("dynamic_stitch forward output mismatch: expected " + expOut + " but got " + out);
    }
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) SameDiff(org.nd4j.autodiff.samediff.SameDiff) DynamicCustomOp(org.nd4j.linalg.api.ops.DynamicCustomOp) Test(org.junit.Test)

Aggregations

SDVariable (org.nd4j.autodiff.samediff.SDVariable)104 SameDiff (org.nd4j.autodiff.samediff.SameDiff)41 INDArray (org.nd4j.linalg.api.ndarray.INDArray)38 Test (org.junit.Test)36 ArrayList (java.util.ArrayList)18 DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp)10 lombok.val (lombok.val)7 LossFunctions (org.nd4j.autodiff.loss.LossFunctions)4 LossInfo (org.nd4j.autodiff.loss.LossInfo)4 BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution)4 Ignore (org.junit.Ignore)3 DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction)3 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)3 Triple (org.nd4j.linalg.primitives.Triple)2 DataOutputStream (java.io.DataOutputStream)1 FileOutputStream (java.io.FileOutputStream)1 ByteBuffer (java.nio.ByteBuffer)1 NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException)1 NdIndexIterator (org.nd4j.linalg.api.iter.NdIndexIterator)1 TruncateDivOp (org.nd4j.linalg.api.ops.impl.transforms.arithmetic.TruncateDivOp)1