Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
The class ATan, method doDiff.
/**
 * Backpropagates through atan(x).
 *
 * <p>Uses atan'(x) = 1 / (1 + x^2) and scales it by the incoming gradient (chain rule).
 *
 * @param i_v incoming gradients; only the first entry is used
 * @return a single-element list holding the gradient with respect to the input
 */
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    SDVariable denominator = f().square(arg()).add(1.0);
    SDVariable derivative = denominator.rdiv(1.0);
    SDVariable upstream = i_v.get(0);
    return Arrays.asList(derivative.mul(upstream));
}
Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
The class ATanh, method doDiff.
/**
 * Backpropagates through atanh(x).
 *
 * <p>Uses atanh'(x) = 1 / (1 - x^2) and scales it by the incoming gradient (chain rule).
 *
 * @param i_v incoming gradients; only the first entry is used
 * @return a single-element list holding the gradient with respect to the input
 */
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    SDVariable xSquared = sameDiff.square(arg());
    // rsub(1.0) gives 1 - x^2; rdiv(1.0) then takes its reciprocal
    SDVariable reciprocal = xSquared.rsub(1.0).rdiv(1.0);
    return Arrays.asList(reciprocal.mul(i_v.get(0)));
}
Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
The class Log, method doDiff.
/**
 * Backpropagates through log(x).
 *
 * <p>Since log'(x) = 1 / x, the gradient with respect to the input is the incoming
 * gradient divided by x.
 *
 * @param i_v incoming gradients; only the first entry is used
 * @return a single-element list holding the gradient with respect to the input
 */
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    SDVariable input = arg();
    f().validateDifferentialFunctionsameDiff(input);
    SDVariable quotient = f().div(i_v.get(0), input);
    return Arrays.asList(sameDiff.setupFunction(quotient));
}
Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
The class Tan, method doDiff.
/**
 * Backpropagates through tan(x).
 *
 * <p>Uses tan'(x) = sec^2(x) = 1 / cos^2(x) and scales it by the incoming gradient.
 *
 * @param i_v incoming gradients; only the first entry is used
 * @return a single-element list holding the gradient with respect to the input
 */
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    SDVariable cosine = sameDiff.cos(arg());
    SDVariable secSquared = sameDiff.square(cosine).rdiv(1.0);
    SDVariable upstream = i_v.get(0);
    return Arrays.asList(secSquared.mul(upstream));
}
Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
The class ClipByNorm, method doDiff.
/**
 * Backpropagates through clip-by-norm.
 *
 * <p>The gradient below assumes this forward definition, with n = norm2(x) over
 * {@code dimensions} and c = {@code clipValue}:
 * <pre>
 *   out = x              if n &lt; c
 *   out = c * x / n      otherwise (the branch selected by gte(n, c))
 * </pre>
 *
 * <p>In the clipped branch the Jacobian is NOT diagonal. Every output element depends on
 * every input element through n:
 * <pre>
 *   d out_i / d x_j = c * (delta_ij / n - x_i * x_j / n^3)
 * </pre>
 * The backpropagated gradient is therefore
 * <pre>
 *   dL/dx = c * (g / n - x * sum(g * x) / n^3)
 * </pre>
 * where the sum runs over the same {@code dimensions} as the norm. The previous version
 * kept only the diagonal term, c * g * (1/n - x^2/n^3), which is wrong whenever the
 * reduction spans more than one element.
 *
 * <p>NOTE(review): if n == 0, the clipped-branch terms evaluate to inf/NaN before being
 * masked by {@code isClippedBC}. Multiplying by a 0 mask does not remove a NaN. The
 * original implementation had the same behavior; confirm whether an epsilon or a
 * select op is needed.
 *
 * @param grad incoming gradients; only the first entry is used
 * @return a single-element list holding the gradient with respect to the input
 */
@Override
public List<SDVariable> doDiff(List<SDVariable> grad) {
    SDVariable x = arg();
    SDVariable g = grad.get(0);
    int origRank = Shape.rankFromShape(x.getShape());

    // Norm over the reduction dimensions, reshaped so it broadcasts against x.
    SDVariable l2norm = f().norm2(x, dimensions);
    SDVariable broadcastableNorm = f().reductionBroadcastableWithOrigShape(origRank, dimensions, l2norm);

    // 1 where the forward pass rescaled the input, 0 where it passed x through unchanged.
    SDVariable isClippedBC = f().gte(broadcastableNorm, clipValue);
    SDVariable notClippedBC = isClippedBC.rsub(1.0);

    // sum(g * x) over the reduction dimensions: the cross term missing from a diagonal-only Jacobian.
    SDVariable dotGX = f().sum(g.mul(x), dimensions);
    SDVariable broadcastableDot = f().reductionBroadcastableWithOrigShape(origRank, dimensions, dotGX);

    // c * (g / n - x * sum(g * x) / n^3)
    SDVariable clippedGrad = g.div(broadcastableNorm)
            .sub(x.mul(broadcastableDot).div(f().cube(broadcastableNorm)))
            .mul(clipValue);

    // Identity gradient where not clipped; full clip-by-norm gradient where clipped.
    SDVariable ret = g.mul(notClippedBC).add(clippedGrad.mul(isClippedBC));
    return Arrays.asList(ret);
}
Aggregations