Example 11 with SDVariable

use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

the class JaccardDistance method doDiff.

@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
    // Jaccard distance: https://en.wikipedia.org/wiki/Jaccard_index#Generalized_Jaccard_similarity_and_distance
    // J(x,y) = 1 - sum_i min(x_i, y_i) / sum_i max(x_i, y_i)
    int rank = Shape.rankFromShape(larg().getShape());
    // jaccard similarity = 1 - jaccard distance
    SDVariable jSim = outputVariables()[0].rsub(1.0);
    SDVariable min = f().min(larg(), rarg());
    SDVariable max = f().max(larg(), rarg());
    SDVariable sumMax = f().sum(max, dimensions);
    SDVariable broadcastableSumMax = f().reductionBroadcastableWithOrigShape(rank, dimensions, sumMax);
    SDVariable broadcastableJSim = f().reductionBroadcastableWithOrigShape(rank, dimensions, jSim);
    SDVariable xIsMin = f().eq(min, larg());
    SDVariable xIsMax = f().eq(max, larg());
    SDVariable yIsMin = f().eq(min, rarg());
    SDVariable yIsMax = f().eq(max, rarg());
    SDVariable dldx = xIsMax.mul(broadcastableJSim).sub(xIsMin).div(broadcastableSumMax);
    SDVariable dldy = yIsMax.mul(broadcastableJSim).sub(yIsMin).div(broadcastableSumMax);
    return Arrays.asList(dldx.mul(f1.get(0)), dldy.mul(f1.get(0)));
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable)
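
The gradient formula in the JaccardDistance doDiff above can be checked without ND4J. The following is a minimal plain-Java sketch, not the library's code: the class and method names (JaccardGradientCheck, analyticGradX, numericGradX) are made up for illustration, inputs are flat arrays reduced over all elements, and ties (x_i == y_i) are avoided because the distance is not differentiable there. It compares the analytic expression (xIsMax * jSim - xIsMin) / sumMax against central finite differences.

```java
import java.util.Arrays;

// Hypothetical helper class, not part of nd4j.
class JaccardGradientCheck {
    // Generalized Jaccard distance: 1 - sum_i min(x_i, y_i) / sum_i max(x_i, y_i)
    static double distance(double[] x, double[] y) {
        double sumMin = 0.0;
        double sumMax = 0.0;
        for (int i = 0; i < x.length; i++) {
            sumMin += Math.min(x[i], y[i]);
            sumMax += Math.max(x[i], y[i]);
        }
        return 1.0 - sumMin / sumMax;
    }

    // Analytic gradient w.r.t. x, mirroring dldx in doDiff:
    // (1[x_i is the max] * jSim - 1[x_i is the min]) / sumMax
    static double[] analyticGradX(double[] x, double[] y) {
        double sumMax = 0.0;
        for (int i = 0; i < x.length; i++) {
            sumMax += Math.max(x[i], y[i]);
        }
        double jSim = 1.0 - distance(x, y);
        double[] grad = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            double xIsMax = x[i] >= y[i] ? 1.0 : 0.0;
            double xIsMin = x[i] <= y[i] ? 1.0 : 0.0;
            grad[i] = (xIsMax * jSim - xIsMin) / sumMax;
        }
        return grad;
    }

    // Central finite differences, used only as an independent reference.
    static double[] numericGradX(double[] x, double[] y, double eps) {
        double[] grad = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            double[] plus = x.clone();
            double[] minus = x.clone();
            plus[i] += eps;
            minus[i] -= eps;
            grad[i] = (distance(plus, y) - distance(minus, y)) / (2.0 * eps);
        }
        return grad;
    }

    public static void main(String[] args) {
        double[] x = {1.0, 4.0, 2.5};
        double[] y = {3.0, 2.0, 0.5};
        System.out.println("analytic: " + Arrays.toString(analyticGradX(x, y)));
        System.out.println("numeric:  " + Arrays.toString(numericGradX(x, y, 1e-6)));
    }
}
```

The two printed arrays should agree to several decimal places. The gradient with respect to y follows by swapping the roles of x and y.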

Example 12 with SDVariable

use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

the class Mmul method doDiff.

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    List<SDVariable> ret = new ArrayList<>();
    SDVariable setup = sameDiff.setupFunction(i_v1.get(0));
    SDVariable gradWrtX = sameDiff.setupFunction(f().reshape(
            f().mmul(setup, rarg(), MMulTranspose.builder()
                    .transposeB(!mMulTranspose.isTransposeB())
                    .transposeResult(mMulTranspose.isTransposeA())
                    .build()),
            larg().getShape()));
    SDVariable gradWrtY = sameDiff.setupFunction(f().reshape(
            f().mmul(larg(), setup, MMulTranspose.builder()
                    .transposeA(!mMulTranspose.isTransposeA())
                    .transposeResult(mMulTranspose.isTransposeB())
                    .build()),
            rarg().getShape()));
    ret.add(gradWrtX);
    ret.add(gradWrtY);
    return ret;
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable)
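
For the plain case with no transpose flags set, the Mmul doDiff above reduces to the familiar rules dL/dA = G * B^T and dL/dB = A^T * G, where G is the incoming gradient. The sketch below illustrates only that special case with plain double[][] arrays. It does not use ND4J, does not model MMulTranspose, and all names in it (MmulGradientCheck, gradA, gradB, maxErrorA) are hypothetical.

```java
// Hypothetical helper class, not part of nd4j; covers the no-transpose case only.
class MmulGradientCheck {
    static double[][] mmul(double[][] a, double[][] b) {
        int n = a.length, k = b.length, m = b[0].length;
        double[][] c = new double[n][m];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < m; j++)
                for (int p = 0; p < k; p++)
                    c[i][j] += a[i][p] * b[p][j];
        return c;
    }

    static double[][] transpose(double[][] a) {
        double[][] t = new double[a[0].length][a.length];
        for (int i = 0; i < a.length; i++)
            for (int j = 0; j < a[0].length; j++)
                t[j][i] = a[i][j];
        return t;
    }

    // dL/dA = G * B^T, same shape as A
    static double[][] gradA(double[][] b, double[][] g) {
        return mmul(g, transpose(b));
    }

    // dL/dB = A^T * G, same shape as B
    static double[][] gradB(double[][] a, double[][] g) {
        return mmul(transpose(a), g);
    }

    // Scalar loss L = sum_ij G_ij * (A B)_ij, chosen so that dL/d(AB) = G
    static double loss(double[][] a, double[][] b, double[][] g) {
        double[][] c = mmul(a, b);
        double sum = 0.0;
        for (int i = 0; i < c.length; i++)
            for (int j = 0; j < c[0].length; j++)
                sum += g[i][j] * c[i][j];
        return sum;
    }

    // Largest absolute gap between gradA and central finite differences of the loss
    static double maxErrorA(double[][] a, double[][] b, double[][] g, double eps) {
        double[][] analytic = gradA(b, g);
        double worst = 0.0;
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < a[0].length; j++) {
                double orig = a[i][j];
                a[i][j] = orig + eps;
                double up = loss(a, b, g);
                a[i][j] = orig - eps;
                double down = loss(a, b, g);
                a[i][j] = orig; // restore the perturbed entry
                worst = Math.max(worst, Math.abs((up - down) / (2.0 * eps) - analytic[i][j]));
            }
        }
        return worst;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2, 3}, {4, 5, 6}};
        double[][] b = {{1, 0}, {0, 1}, {2, -1}};
        double[][] g = {{1, 2}, {3, 4}};
        System.out.println("max error for dL/dA: " + maxErrorA(a, b, g, 1e-4));
    }
}
```

The reshape calls in doDiff correspond here to the fact that gradA has the shape of A (2x3) and gradB has the shape of B (3x2).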

Example 13 with SDVariable

use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

the class NormMax method doDiff.

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    // maxnorm(in) = max_i |x_i|
    // d maxnorm(in)/dx = 0 if x_i is not the max, or d|x|/dx otherwise
    SDVariable absIn = sameDiff.abs(arg());
    SDVariable maxnorm = outputVariables()[0];
    // TODO shape may not always be defined?
    int origRank = Shape.rankFromShape(arg().getShape());
    SDVariable maxnormBc = f().reductionBroadcastableWithOrigShape(origRank, dimensions, maxnorm);
    maxnormBc = sameDiff.onesLike(arg()).mul(maxnormBc);
    SDVariable eq = sameDiff.eq(absIn, maxnormBc);
    SDVariable dAbsXdX = sameDiff.sign(arg());
    SDVariable dNormmaxDx = eq.mul(dAbsXdX);
    SDVariable broadcastableGrad = f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0));
    SDVariable ret = dNormmaxDx.mul(broadcastableGrad);
    return Arrays.asList(ret);
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable)
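
The masking logic in the NormMax doDiff above (eq(|x|, maxnorm) multiplied by sign(x)) can be illustrated on a flat array. This is a plain-Java sketch under the assumption of a single reduction over all elements. The class NormMaxGradient and its methods are hypothetical and are not the ND4J API.

```java
import java.util.Arrays;

// Hypothetical helper class, not part of nd4j.
class NormMaxGradient {
    // maxnorm(x) = max_i |x_i|
    static double maxNorm(double[] x) {
        double max = 0.0;
        for (double v : x) {
            max = Math.max(max, Math.abs(v));
        }
        return max;
    }

    // Gradient: sign(x_i) * upstream where |x_i| attains the max, 0 elsewhere.
    // As with the eq mask in doDiff, every tied element receives a gradient.
    static double[] gradient(double[] x, double upstream) {
        double norm = maxNorm(x);
        double[] grad = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            boolean attainsMax = Math.abs(x[i]) == norm;
            grad[i] = attainsMax ? Math.signum(x[i]) * upstream : 0.0;
        }
        return grad;
    }

    public static void main(String[] args) {
        double[] x = {1.0, -5.0, 3.0};
        // prints [0.0, -1.0, 0.0]
        System.out.println(Arrays.toString(gradient(x, 1.0)));
    }
}
```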

Example 14 with SDVariable

use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

the class Prod method doDiff.

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    SDVariable prod = outputVariables()[0];
    // TODO shape may not always be defined?
    int origRank = Shape.rankFromShape(arg().getShape());
    SDVariable broadcastableGrad = sameDiff.f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0));
    SDVariable broadcastableProd = sameDiff.f().reductionBroadcastableWithOrigShape(origRank, dimensions, prod);
    // d prod/dx_i = prod / x_i; this form is undefined (0/0) where x_i == 0
    SDVariable mul = broadcastableGrad.div(arg());
    SDVariable ret = broadcastableProd.mul(mul);
    return Arrays.asList(ret);
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable)
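
The Prod doDiff above computes the gradient as prod / x_i times the incoming gradient, which evaluates to 0/0 when an input element is zero. The plain-Java sketch below shows that formula on a flat array next to an alternative that multiplies prefix and suffix products instead of dividing. The alternative is a different technique swapped in for comparison and is not what the snippet above does. Both method names are hypothetical.

```java
import java.util.Arrays;

// Hypothetical helper class, not part of nd4j.
class ProdGradient {
    // Same formula as the doDiff above: upstream * prod / x_i
    static double[] gradByDivision(double[] x, double upstream) {
        double prod = 1.0;
        for (double v : x) {
            prod *= v;
        }
        double[] grad = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            grad[i] = upstream * prod / x[i];
        }
        return grad;
    }

    // Alternative (not from the source): product of all elements except x_i,
    // built from a running prefix product and a running suffix product.
    static double[] gradByExclusiveProduct(double[] x, double upstream) {
        int n = x.length;
        double[] grad = new double[n];
        double prefix = 1.0;
        for (int i = 0; i < n; i++) {
            grad[i] = prefix;
            prefix *= x[i];
        }
        double suffix = 1.0;
        for (int i = n - 1; i >= 0; i--) {
            grad[i] *= suffix * upstream;
            suffix *= x[i];
        }
        return grad;
    }

    public static void main(String[] args) {
        double[] x = {2.0, 3.0, 4.0};
        System.out.println(Arrays.toString(gradByDivision(x, 1.0)));         // prints [12.0, 8.0, 6.0]
        System.out.println(Arrays.toString(gradByExclusiveProduct(x, 1.0))); // prints [12.0, 8.0, 6.0]

        double[] withZero = {2.0, 0.0, 4.0};
        System.out.println(Arrays.toString(gradByDivision(withZero, 1.0)));         // prints [0.0, NaN, 0.0]
        System.out.println(Arrays.toString(gradByExclusiveProduct(withZero, 1.0))); // prints [0.0, 8.0, 0.0]
    }
}
```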

Example 15 with SDVariable

use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

the class StandardDeviation method doDiff.

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    // Here: calculating dL/dIn given dL/dOut (i.e., i_v1) and input/output
    // If out = stdev(in) then:
    // dL/dIn = dL/dOut * dOut/dIn
    // dOut/dIn_i = (in_i-mean)/(stdev * (n-1)) if bias corrected, or (in_i-mean)/(stdev * n) otherwise
    int origRank = Shape.rankFromShape(arg().getShape());
    int n = f().getReductionLength(this);
    SDVariable broadcastableStdevOut = f().reductionBroadcastableWithOrigShape(origRank, dimensions, outputVariables()[0]);
    SDVariable broadcastableMean = f().reductionBroadcastableWithOrigShape(origRank, dimensions, f().mean(arg(), dimensions));
    SDVariable diff = arg().sub(broadcastableMean);
    SDVariable dOutdIn = diff.div(broadcastableStdevOut);
    if (this.biasCorrected) {
        dOutdIn = dOutdIn.div(n - 1);
    } else {
        dOutdIn = dOutdIn.div(n);
    }
    SDVariable broadcastableGrad = f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0));
    SDVariable dLdIn = dOutdIn.mul(broadcastableGrad);
    return Arrays.asList(dLdIn);
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable)
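
The derivative used in the StandardDeviation doDiff above, (in_i - mean) / (stdev * d) with d = n - 1 when bias corrected and d = n otherwise, can be verified numerically. The following plain-Java sketch assumes a reduction over a flat array and compares the formula with central finite differences. The class StdevGradientCheck and its methods are hypothetical names, not ND4J calls.

```java
import java.util.Arrays;

// Hypothetical helper class, not part of nd4j.
class StdevGradientCheck {
    static double mean(double[] x) {
        double sum = 0.0;
        for (double v : x) {
            sum += v;
        }
        return sum / x.length;
    }

    static double stdev(double[] x, boolean biasCorrected) {
        double m = mean(x);
        double ss = 0.0;
        for (double v : x) {
            ss += (v - m) * (v - m);
        }
        double denom = biasCorrected ? x.length - 1 : x.length;
        return Math.sqrt(ss / denom);
    }

    // dL/dIn_i = upstream * (in_i - mean) / (stdev * denom)
    static double[] analyticGrad(double[] x, boolean biasCorrected, double upstream) {
        double m = mean(x);
        double s = stdev(x, biasCorrected);
        double denom = biasCorrected ? x.length - 1 : x.length;
        double[] grad = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            grad[i] = upstream * (x[i] - m) / (s * denom);
        }
        return grad;
    }

    // Central finite differences of stdev, used only as an independent reference.
    static double[] numericGrad(double[] x, boolean biasCorrected, double eps) {
        double[] grad = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            double[] plus = x.clone();
            double[] minus = x.clone();
            plus[i] += eps;
            minus[i] -= eps;
            grad[i] = (stdev(plus, biasCorrected) - stdev(minus, biasCorrected)) / (2.0 * eps);
        }
        return grad;
    }

    public static void main(String[] args) {
        double[] x = {1.0, 2.0, 4.0, 7.0};
        System.out.println("analytic: " + Arrays.toString(analyticGrad(x, true, 1.0)));
        System.out.println("numeric:  " + Arrays.toString(numericGrad(x, true, 1e-5)));
    }
}
```

Because the deviations from the mean sum to zero, the gradient entries also sum to zero (up to rounding), which is a quick sanity check on any implementation.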

Aggregations

SDVariable (org.nd4j.autodiff.samediff.SDVariable) 104
SameDiff (org.nd4j.autodiff.samediff.SameDiff) 41
INDArray (org.nd4j.linalg.api.ndarray.INDArray) 38
Test (org.junit.Test) 36
ArrayList (java.util.ArrayList) 18
DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp) 10
lombok.val (lombok.val) 7
LossFunctions (org.nd4j.autodiff.loss.LossFunctions) 4
LossInfo (org.nd4j.autodiff.loss.LossInfo) 4
BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution) 4
Ignore (org.junit.Ignore) 3
DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction) 3
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException) 3
Triple (org.nd4j.linalg.primitives.Triple) 2
DataOutputStream (java.io.DataOutputStream) 1
FileOutputStream (java.io.FileOutputStream) 1
ByteBuffer (java.nio.ByteBuffer) 1
NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException) 1
NdIndexIterator (org.nd4j.linalg.api.iter.NdIndexIterator) 1
TruncateDivOp (org.nd4j.linalg.api.ops.impl.transforms.arithmetic.TruncateDivOp) 1