
Example 61 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in the nd4j project by deeplearning4j.

The doDiff method of the class SpaceToDepth:

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // The gradient of SpaceToDepth is just DepthToSpace with the same block size and data format.
    SDVariable gradient = i_v.get(0);
    SDVariable ret = sameDiff.depthToSpace(gradient, blockSize, dataFormat);
    return Arrays.asList(ret);
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable)
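Since space-to-depth only permutes elements, its Jacobian is a permutation matrix and the backward pass is the inverse permutation, i.e. depth-to-space with the same block size. The following is a minimal standalone sketch in plain Java (the class and method names are hypothetical, not part of the nd4j API) demonstrating the round trip for a single-channel H x W image:

```java
import java.util.Arrays;

// Hypothetical sketch, not the nd4j implementation: space-to-depth with block
// size b maps an H x W plane to an (H/b) x (W/b) x (b*b) volume by pure
// element permutation; depth-to-space is the inverse permutation.
public class SpaceToDepthSketch {

    static double[][][] spaceToDepth(double[][] in, int b) {
        int H = in.length, W = in[0].length;
        double[][][] out = new double[H / b][W / b][b * b];
        for (int i = 0; i < H; i++)
            for (int j = 0; j < W; j++)
                // each b x b spatial block becomes one depth fiber
                out[i / b][j / b][(i % b) * b + (j % b)] = in[i][j];
        return out;
    }

    static double[][] depthToSpace(double[][][] in, int b) {
        int h = in.length, w = in[0].length;
        double[][] out = new double[h * b][w * b];
        for (int i = 0; i < h; i++)
            for (int j = 0; j < w; j++)
                for (int k = 0; k < b * b; k++)
                    // scatter each depth fiber back into its b x b block
                    out[i * b + k / b][j * b + k % b] = in[i][j][k];
        return out;
    }

    public static void main(String[] args) {
        double[][] x = {{1, 2}, {3, 4}};
        double[][] back = depthToSpace(spaceToDepth(x, 2), 2);
        System.out.println(Arrays.deepEquals(x, back)); // prints "true"
    }
}
```

Because depthToSpace(spaceToDepth(x, b), b) is the identity, applying depth-to-space to the incoming gradient routes each gradient element back to the input element that produced it, which is exactly what the doDiff above does.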

Example 62 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in the nd4j project by deeplearning4j.

The doDiff method of the class Sum:

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    // Out = sum(in)
    // dL/dIn = dL/dOut * dOut/dIn
    // = dL/dOut * 1
    // But broadcast to shape of the input
    // TODO shape may not always be defined?
    int origRank = Shape.rankFromShape(arg().getShape());
    SDVariable broadcastable = sameDiff.f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0));
    SDVariable ret = sameDiff.onesLike(arg()).mul(broadcastable);
    return Arrays.asList(ret);
}
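The rule above can be sanity-checked without nd4j: for out = sum(in), each partial derivative dOut/dIn_i is 1, so the input gradient is just the upstream gradient broadcast to the input shape. A minimal sketch in plain Java (the class name is hypothetical), verified against finite differences:

```java
import java.util.Arrays;

// Hypothetical sketch: the backward pass of a full sum-reduction broadcasts
// the upstream scalar gradient dL/dOut to every input position, since
// dOut/dIn_i = 1 for all i.
public class SumGradSketch {

    static double[] sumGrad(double gradOut, int inputLength) {
        double[] g = new double[inputLength];
        Arrays.fill(g, gradOut); // broadcast: each element receives dL/dOut * 1
        return g;
    }

    static double sum(double[] x) {
        double s = 0;
        for (double v : x) s += v;
        return s;
    }

    public static void main(String[] args) {
        double[] x = {1, 2, 3};
        double eps = 1e-6;
        double[] g = sumGrad(1.0, x.length); // upstream gradient of 1
        for (int i = 0; i < x.length; i++) {
            double[] xp = x.clone();
            xp[i] += eps;
            double fd = (sum(xp) - sum(x)) / eps; // finite-difference estimate
            System.out.println(Math.abs(fd - g[i]) < 1e-4); // prints "true"
        }
    }
}
```

The onesLike(arg()).mul(broadcastable) in the doDiff above is the rank-preserving version of the same idea: reductionBroadcastableWithOrigShape reshapes the upstream gradient so it broadcasts against the original input shape.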

Example 63 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in the nd4j project by deeplearning4j.

The doDiff method of the class TensorMmul:

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    List<SDVariable> ret = new ArrayList<>();
    int[] bAxes = range(0, rarg().getShape().length);
    int[] aAxes = range(0, larg().getShape().length);
    int aRank = larg().getShape().length;
    int bRank = rarg().getShape().length;
    int[][] sumAxes = new int[][] { mod(axes[0], aRank), mod(axes[1], bRank) };
    int[][] deletedAxes = new int[][] { removeIndex(aAxes, sumAxes[0]), removeIndex(bAxes, sumAxes[1]) };
    int[] gAxes = range(0, i_v1.get(0).getShape().length);
    int[][] firstAxes = new int[][] { Arrays.copyOfRange(gAxes, deletedAxes[0].length, gAxes.length), deletedAxes[1] };
    int[][] secondAxes = new int[][] { deletedAxes[0], Arrays.copyOfRange(gAxes, 0, deletedAxes[0].length) };
    // tensor matrix multiply gradient wrt second variable
    int[] firstPerm = argsort(combine(deletedAxes[0], keep(argsort(sumAxes[1]), sumAxes[0])));
    SDVariable firstResult = doTensorMmul(i_v1.get(0), rarg(), firstAxes);
    SDVariable permuted = f().permute(firstResult, firstPerm);
    ret.add(permuted);
    // tensor matrix multiply gradient wrt first variable
    int[] secondPerm = argsort(combine(keep(argsort(sumAxes[0]), sumAxes[1]), deletedAxes[1]));
    SDVariable secondResult = doTensorMmul(i_v1.get(0), larg(), secondAxes);
    SDVariable secondPermuted = f().permute(secondResult, secondPerm);
    ret.add(secondPermuted);
    return ret;
}
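The index bookkeeping above implements the general tensordot gradient. For the special case of a plain matrix multiply C = A * B (tensordot over A's axis 1 and B's axis 0) it reduces to the familiar dL/dA = dL/dC * B^T and dL/dB = A^T * dL/dC. A standalone sketch of that special case (hypothetical class name, plain Java arrays), checked by finite differences:

```java
// Hypothetical sketch of the matrix-multiply special case of the tensordot
// gradient: with loss L = sum(A * B), dL/dC is all ones, so
// dL/dA[i][p] = sum_j B[p][j] (i.e. ones * B^T).
public class MmulGradSketch {

    static double[][] mmul(double[][] a, double[][] b) {
        int n = a.length, m = b[0].length, k = b.length;
        double[][] c = new double[n][m];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < m; j++)
                for (int p = 0; p < k; p++)
                    c[i][j] += a[i][p] * b[p][j];
        return c;
    }

    static double sum(double[][] m) {
        double s = 0;
        for (double[] row : m)
            for (double v : row) s += v;
        return s;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2}, {3, 4}};
        double[][] b = {{5, 6}, {7, 8}};
        // analytic gradient of L = sum(A*B) w.r.t. a[0][1]: sum_j b[1][j] = 15
        double analytic = b[1][0] + b[1][1];
        double eps = 1e-6;
        double base = sum(mmul(a, b));
        a[0][1] += eps;
        double fd = (sum(mmul(a, b)) - base) / eps; // finite-difference estimate
        System.out.println(Math.abs(fd - analytic) < 1e-3); // prints "true"
    }
}
```

The extra argsort/permute machinery in the doDiff above does for arbitrary contraction axes what the transpose does in this two-axis special case.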

Example 64 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in the nd4j project by deeplearning4j.

The doDiff method of the class EuclideanDistance:

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    // ddist(x,y)/dxi = (xi-yi)/dist(x,y)
    SDVariable euc = outputVariables()[0];
    SDVariable difference = larg().sub(rarg());
    SDVariable divBroadcastable;
    // TODO shape may not always be defined?
    int origRank = Shape.rankFromShape(arg().getShape());
    if (!(dimensions.length == 1 && dimensions[0] == Integer.MAX_VALUE)) {
        // reduced along a subset of dimensions: broadcast the gradient back to the input rank
        divBroadcastable = f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0).div(euc));
    } else {
        // full reduction (1x1 output): the scalar gradient broadcasts against the inputs as-is
        divBroadcastable = i_v1.get(0).div(euc);
    }
    SDVariable gradX = difference.mul(divBroadcastable);
    SDVariable gradY = f().neg(gradX);
    return Arrays.asList(gradX, gradY);
}
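The formula in the comment, ddist(x,y)/dx_i = (x_i - y_i)/dist(x,y), follows from differentiating sqrt(sum_i (x_i - y_i)^2); the gradient with respect to y is its negation, as the final f().neg(gradX) reflects. A standalone numeric check in plain Java (hypothetical class name):

```java
// Hypothetical sketch: verify d dist(x,y)/dx_i = (x_i - y_i) / dist(x,y)
// by comparing the analytic gradient with finite differences.
public class EuclideanGradSketch {

    static double dist(double[] x, double[] y) {
        double s = 0;
        for (int i = 0; i < x.length; i++)
            s += (x[i] - y[i]) * (x[i] - y[i]);
        return Math.sqrt(s);
    }

    public static void main(String[] args) {
        double[] x = {1, 2, 3};
        double[] y = {4, 0, 1};
        double d = dist(x, y);
        double eps = 1e-6;
        for (int i = 0; i < x.length; i++) {
            double analytic = (x[i] - y[i]) / d; // gradX; gradY is its negation
            double[] xp = x.clone();
            xp[i] += eps;
            double fd = (dist(xp, y) - d) / eps; // finite-difference estimate
            System.out.println(Math.abs(fd - analytic) < 1e-4); // prints "true"
        }
    }
}
```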

Example 65 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in the nd4j project by deeplearning4j.

The doDiff method of the class Variance:

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    // If out = var(in) then:
    // dL/dIn = dL/dOut * dOut/dIn
    // with dOut/dIn = (in-mean) * 2/(n-1)
    int n = f().getReductionLength(this);
    int origRank = Shape.rankFromShape(arg().getShape());
    SDVariable broadcastableMean = f().reductionBroadcastableWithOrigShape(origRank, dimensions, f().mean(arg(), dimensions));
    SDVariable broadcastableGrad = f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0));
    SDVariable dOutdIn = arg().sub(broadcastableMean).mul(2.0 / (biasCorrected ? (n - 1) : n));
    SDVariable dLdIn = dOutdIn.mul(broadcastableGrad);
    return Arrays.asList(dLdIn);
}
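The factor 2/(n-1) in the bias-corrected case is exact, not an approximation: differentiating var = (1/(n-1)) * sum_j (x_j - mean)^2 with respect to x_i also perturbs the mean, but that contribution multiplies sum_j (x_j - mean), which is identically zero, so dVar/dx_i = 2*(x_i - mean)/(n-1). A standalone numeric check in plain Java (hypothetical class name):

```java
// Hypothetical sketch: verify dVar/dx_i = 2*(x_i - mean)/(n-1) for the
// bias-corrected sample variance, using finite differences.
public class VarianceGradSketch {

    static double var(double[] x) {
        double m = 0;
        for (double v : x) m += v;
        m /= x.length;
        double s = 0;
        for (double v : x) s += (v - m) * (v - m);
        return s / (x.length - 1); // bias-corrected
    }

    public static void main(String[] args) {
        double[] x = {1, 4, 2, 8};
        double mean = (1 + 4 + 2 + 8) / 4.0;
        double eps = 1e-6;
        for (int i = 0; i < x.length; i++) {
            double analytic = 2 * (x[i] - mean) / (x.length - 1);
            double[] xp = x.clone();
            xp[i] += eps;
            double fd = (var(xp) - var(x)) / eps; // finite-difference estimate
            System.out.println(Math.abs(fd - analytic) < 1e-4); // prints "true"
        }
    }
}
```

For the uncorrected case the same argument gives 2*(x_i - mean)/n, matching the biasCorrected ternary in the doDiff above.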

Aggregations

SDVariable (org.nd4j.autodiff.samediff.SDVariable): 104 usages
SameDiff (org.nd4j.autodiff.samediff.SameDiff): 41 usages
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 38 usages
Test (org.junit.Test): 36 usages
ArrayList (java.util.ArrayList): 18 usages
DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp): 10 usages
lombok.val (lombok.val): 7 usages
LossFunctions (org.nd4j.autodiff.loss.LossFunctions): 4 usages
LossInfo (org.nd4j.autodiff.loss.LossInfo): 4 usages
BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution): 4 usages
Ignore (org.junit.Ignore): 3 usages
DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction): 3 usages
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 3 usages
Triple (org.nd4j.linalg.primitives.Triple): 2 usages
DataOutputStream (java.io.DataOutputStream): 1 usage
FileOutputStream (java.io.FileOutputStream): 1 usage
ByteBuffer (java.nio.ByteBuffer): 1 usage
NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException): 1 usage
NdIndexIterator (org.nd4j.linalg.api.iter.NdIndexIterator): 1 usage
TruncateDivOp (org.nd4j.linalg.api.ops.impl.transforms.arithmetic.TruncateDivOp): 1 usage