Search in sources:

Example 1 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in the nd4j project by deeplearning4j.

The doDiff method of the Erfc class.

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // erfc(z) = 1 - erf(z)
    // The derivative of erf(z) is 2 / sqrt(pi) * e^(-z^2), so the erfc derivative is that expression multiplied by -1.
    SDVariable gradient = i_v.get(0);
    SDVariable constant = sameDiff.onesLike(gradient).mul(-2).div(Math.sqrt(Math.PI));
    SDVariable ret = constant.mul(sameDiff.exp(gradient.mul(gradient).mul(-1)));
    return Arrays.asList(ret);
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable)
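In the snippet above, the exponential is evaluated on the incoming gradient i_v.get(0) rather than on the op's own input. The following is a minimal sketch, not the project's code, of how the same helpers could apply the chain rule explicitly, assuming arg() returns the op's input variable (as in the ClipByValue example below) and i_v.get(0) is the gradient flowing back from the output:

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // Sketch only: d/dz erfc(z) = -2 / sqrt(pi) * e^(-z^2).
    SDVariable z = arg();
    SDVariable localDeriv = sameDiff.onesLike(z).mul(-2).div(Math.sqrt(Math.PI))
            .mul(sameDiff.exp(z.mul(z).mul(-1)));
    // Chain rule: multiply the local derivative by the upstream gradient.
    SDVariable ret = localDeriv.mul(i_v.get(0));
    return Arrays.asList(ret);
}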

Example 2 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in the nd4j project by deeplearning4j.

The doDiff method of the OldAddOp class.

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // z = x + y, so dz/dx = dz/dy = 1: both gradients are just the incoming gradient.
    SDVariable gradWrtX = i_v.get(0);
    SDVariable gradWrtY = i_v.get(0);
    List<SDVariable> ret = new ArrayList<>(2);
    ret.add(gradWrtX);
    ret.add(gradWrtY);
    return ret;
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), ArrayList (java.util.ArrayList)
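For context, here is a minimal forward-pass sketch that pairs with a doDiff like the one above. The exact construction API varies between nd4j versions; SameDiff.create(), var(String, INDArray), and Nd4j.ones are assumed to be available here:

SameDiff sd = SameDiff.create();
SDVariable x = sd.var("x", Nd4j.ones(2, 2));
SDVariable y = sd.var("y", Nd4j.ones(2, 2));
// The forward op whose backward pass doDiff defines: z = x + y.
SDVariable z = x.add(y);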

Example 3 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in the nd4j project by deeplearning4j.

The doDiff method of the OldSubOp class.

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // z = x - y, so dz/dx = 1 and dz/dy = -1.
    SDVariable gradWrtX = i_v.get(0);
    SDVariable gradWrtY = f().neg(i_v.get(0));
    List<SDVariable> ret = new ArrayList<>(2);
    ret.add(gradWrtX);
    ret.add(gradWrtY);
    return ret;
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), ArrayList (java.util.ArrayList)

Example 4 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in the nd4j project by deeplearning4j.

The doDiff method of the ClipByValue class.

@Override
public List<SDVariable> doDiff(List<SDVariable> grad) {
    // dOut/dIn is 0 if clipped, 1 otherwise
    SDVariable out = outputVariables()[0];
    SDVariable notClippedLower = f().gt(arg(), clipValueMin);
    SDVariable notClippedUpper = f().lt(arg(), clipValueMax);
    SDVariable ret = notClippedLower.mul(notClippedUpper).mul(grad.get(0));
    return Arrays.asList(ret);
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable)
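A plain-Java illustration (no nd4j API involved) of the 0/1 gradient mask described in the comment above, with hypothetical clip bounds of -1 and 1:

double clipValueMin = -1.0, clipValueMax = 1.0;
double[] in       = { -2.0, 0.5, 3.0 };   // op inputs
double[] upstream = {  1.0, 1.0, 1.0 };   // incoming gradient
double[] gradIn   = new double[in.length];
for (int i = 0; i < in.length; i++) {
    // Gradient passes through only where the input was not clipped.
    boolean notClipped = in[i] > clipValueMin && in[i] < clipValueMax;
    gradIn[i] = notClipped ? upstream[i] : 0.0;
}
// gradIn is { 0.0, 1.0, 0.0 }: clipped entries block the gradient.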

Example 5 with SDVariable

Use of org.nd4j.autodiff.samediff.SDVariable in the nd4j project by deeplearning4j.

The doDiff method of the Set class.

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    val shape = outputVariables()[0].getShape();
    // Computed here but not used below.
    SDVariable ym1 = f().rsub(rarg(), f().one(shape));
    // ret = rarg() * larg()^2 * larg()
    SDVariable ret = f().mul(f().mul(rarg(), f().pow(larg(), 2.0)), larg());
    return Arrays.asList(ret);
}
Also used: lombok.val (lombok.val), SDVariable (org.nd4j.autodiff.samediff.SDVariable)

Aggregations

SDVariable (org.nd4j.autodiff.samediff.SDVariable) 104
SameDiff (org.nd4j.autodiff.samediff.SameDiff) 41
INDArray (org.nd4j.linalg.api.ndarray.INDArray) 38
Test (org.junit.Test) 36
ArrayList (java.util.ArrayList) 18
DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp) 10
lombok.val (lombok.val) 7
LossFunctions (org.nd4j.autodiff.loss.LossFunctions) 4
LossInfo (org.nd4j.autodiff.loss.LossInfo) 4
BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution) 4
Ignore (org.junit.Ignore) 3
DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction) 3
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException) 3
Triple (org.nd4j.linalg.primitives.Triple) 2
DataOutputStream (java.io.DataOutputStream) 1
FileOutputStream (java.io.FileOutputStream) 1
ByteBuffer (java.nio.ByteBuffer) 1
NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException) 1
NdIndexIterator (org.nd4j.linalg.api.iter.NdIndexIterator) 1
TruncateDivOp (org.nd4j.linalg.api.ops.impl.transforms.arithmetic.TruncateDivOp) 1
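The four most frequently aggregated classes above (SDVariable, SameDiff, INDArray, Test) typically appear together in SameDiff unit tests. The sketch below is hypothetical: the class and test names are illustrative, and graph-execution calls are omitted because they differ across nd4j versions.

import org.junit.Test;
import static org.junit.Assert.assertNotNull;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class SDVariableUsageSketchTest {

    @Test
    public void buildSimpleGraph() {
        SameDiff sd = SameDiff.create();
        INDArray data = Nd4j.ones(2, 2);
        // Register the array as a graph variable, then build an elementwise square on top of it.
        SDVariable x = sd.var("x", data);
        SDVariable y = x.mul(x);
        assertNotNull(y);
    }
}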