Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
In class InvertPermutation, method doDiff.
@Override
public List<SDVariable> doDiff(List<SDVariable> grad) {
    // invertPermutation maps an index vector to another index vector. Its output is
    // piecewise constant in its input, so the gradient w.r.t. the input is zero
    // (TensorFlow likewise treats this op as non-differentiable).
    // Passing the upstream gradient itself through invertPermutation would treat
    // gradient values as permutation indices, which has no mathematical meaning.
    // The upstream gradient (grad.get(0)) is therefore intentionally unused.
    return Arrays.asList(sameDiff.zerosLike(arg()));
}
Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
In class Cross, method doDiff.
@Override
public List<SDVariable> doDiff(List<SDVariable> gradients) {
    /*
     * Backprop through c = a x b with upstream gradient g = dL/dc.
     *
     * By the scalar triple product identity,
     *   g . (a x b) = a . (b x g) = b . (g x a),
     * so
     *   dL/da = b x g
     *   dL/db = g x a
     *
     * Note: element-wise g * Cross(1, b) is NOT equivalent. It mixes the wrong
     * components of g and has the opposite sign of b x 1.
     *
     * Assumes a, b and g share the same shape, with the 3-vectors along the axis
     * that sameDiff.cross operates on. TODO confirm for batched inputs.
     */
    SDVariable grad = gradients.get(0);
    SDVariable a = larg();
    SDVariable b = rarg();
    SDVariable gradLeft = sameDiff.cross(b, grad);
    SDVariable gradRight = sameDiff.cross(grad, a);
    return Arrays.asList(gradLeft, gradRight);
}
Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
In class DiagPart, method doDiff.
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // DiagPart extracts a diagonal, so the gradient is sent back by re-expanding
    // the upstream gradient with diag (presumably placing it on the diagonal of an
    // otherwise-zero tensor; confirm against SameDiff.diag).
    return Arrays.asList(sameDiff.diag(i_v.get(0)));
}
Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
In class ACos, method doDiff.
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // Chain rule: acos'(x) = -1 / sqrt(1 - x^2), scaled by the upstream gradient.
    SDVariable denominator = f().sqrt(f().square(arg()).rsub(1.0));
    SDVariable localDerivative = denominator.rdiv(-1.0);
    SDVariable upstream = i_v.get(0);
    return Arrays.asList(localDerivative.mul(upstream));
}
Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
In class ASin, method doDiff.
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // Chain rule: asin'(x) = 1 / sqrt(1 - x^2), scaled by the upstream gradient.
    SDVariable root = sameDiff.sqrt(sameDiff.square(arg()).rsub(1.0));
    SDVariable scaled = root.rdiv(1.0).mul(i_v.get(0));
    return Arrays.asList(scaled);
}
Aggregations