Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
Class Sqrt, method doDiff:
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // arg() is the op input x; d(sqrt(x))/dx = 0.5 * x^(-1/2),
    // combined with the incoming gradient i_v.get(0) by the chain rule
    SDVariable in = arg();
    SDVariable g = sameDiff.pow(in, -0.5).mul(0.5).mul(i_v.get(0));
    return Arrays.asList(g);
}
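For reference, the derivative implemented above is d(√x)/dx = 0.5 · x^(-1/2), scaled by the incoming gradient. A minimal, framework-free sketch (plain Java; the class and method names are illustrative, not part of nd4j) that checks the analytic form against a central finite difference:

public class SqrtGradCheck {
    // Analytic gradient as computed by Sqrt.doDiff: 0.5 * x^(-0.5) * upstream
    static double sqrtGrad(double x, double upstream) {
        return 0.5 * Math.pow(x, -0.5) * upstream;
    }

    public static void main(String[] args) {
        double x = 2.5, upstream = 1.0, eps = 1e-6;
        // Central finite difference of sqrt at x
        double numeric = (Math.sqrt(x + eps) - Math.sqrt(x - eps)) / (2 * eps);
        System.out.printf("analytic=%.8f numeric=%.8f%n", sqrtGrad(x, upstream), numeric);
    }
}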
Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
Class OldRDivOp, method doDiff:
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // Reverse division computes out = rarg() / larg(), so larg() is the denominator:
    // d(out)/d(rarg) = 1 / larg, and d(out)/d(larg) = -rarg / larg^2
    SDVariable gradWrtX = f().div(i_v.get(0), larg());
    SDVariable gradWrtY = f().mul(f().neg(gradWrtX), f().div(rarg(), larg()));
    List<SDVariable> ret = new ArrayList<>(2);
    ret.add(gradWrtX);
    ret.add(gradWrtY);
    return ret;
}
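As a sanity check on the quotient rule used above: with out = y/x (x = larg(), y = rarg()), the partials are ∂out/∂y = 1/x and ∂out/∂x = -y/x². A hedged plain-Java sketch (names are illustrative, not part of nd4j) comparing both against finite differences:

public class RDivGradCheck {
    public static void main(String[] args) {
        double x = 3.0, y = 2.0, eps = 1e-6;
        // Analytic partials of f(x, y) = y / x
        double dfdy = 1.0 / x;          // corresponds to grad / larg()
        double dfdx = -y / (x * x);     // corresponds to -gradWrtX * (rarg() / larg())
        // Central finite differences
        double numDy = ((y + eps) / x - (y - eps) / x) / (2 * eps);
        double numDx = (y / (x + eps) - y / (x - eps)) / (2 * eps);
        System.out.printf("d/dy: analytic=%.8f numeric=%.8f%n", dfdy, numDy);
        System.out.printf("d/dx: analytic=%.8f numeric=%.8f%n", dfdx, numDx);
    }
}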
Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
Class TruncateDivOp, method doDiff:
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // Differentiated as ordinary division out = larg() / rarg():
    // d(out)/d(larg) = 1 / rarg, and d(out)/d(rarg) = -larg / rarg^2
    SDVariable gradWrtX = f().div(i_v.get(0), rarg());
    SDVariable gradWrtY = f().mul(f().neg(gradWrtX), f().div(larg(), rarg()));
    List<SDVariable> ret = new ArrayList<>(2);
    ret.add(gradWrtX);
    ret.add(gradWrtY);
    return ret;
}
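Note that the truncation itself is ignored in the backward pass, as is common for rounding-style ops whose exact derivative is zero almost everywhere; the op is differentiated as plain x/y. The same style of check applies (hypothetical class name):

public class DivGradCheck {
    public static void main(String[] args) {
        double x = 5.0, y = 2.0, eps = 1e-6;
        // Analytic partials of f(x, y) = x / y
        double dfdx = 1.0 / y;          // grad / rarg()
        double dfdy = -x / (y * y);     // -gradWrtX * (larg() / rarg())
        // Central finite differences
        double numDx = ((x + eps) / y - (x - eps) / y) / (2 * eps);
        double numDy = (x / (y + eps) - x / (y - eps)) / (2 * eps);
        System.out.printf("d/dx: analytic=%.8f numeric=%.8f%n", dfdx, numDx);
        System.out.printf("d/dy: analytic=%.8f numeric=%.8f%n", dfdy, numDy);
    }
}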
Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
Class DepthToSpace, method doDiff:
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // The gradient of DepthToSpace is simply SpaceToDepth with the same
    // block size and data format, since the op only rearranges elements.
    SDVariable gradient = i_v.get(0);
    SDVariable ret = sameDiff.spaceToDepth(gradient, blockSize, dataFormat);
    return Arrays.asList(ret);
}
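Because DepthToSpace is a pure data-movement op, its Jacobian is a permutation, and the backward pass is the inverse rearrangement: SpaceToDepth with the same block size. A small shape-bookkeeping sketch (NHWC layout assumed; the helper names are hypothetical) showing the round trip the gradient takes:

public class DepthToSpaceShapes {
    // Shape bookkeeping only: NHWC layout, block size b
    static int[] depthToSpace(int[] s, int b) {
        return new int[] { s[0], s[1] * b, s[2] * b, s[3] / (b * b) };
    }

    static int[] spaceToDepth(int[] s, int b) {
        return new int[] { s[0], s[1] / b, s[2] / b, s[3] * b * b };
    }

    public static void main(String[] args) {
        int[] in = { 1, 4, 4, 8 };   // N, H, W, C with C divisible by b*b
        int b = 2;
        int[] fwd = depthToSpace(in, b);    // forward output shape
        int[] back = spaceToDepth(fwd, b);  // shape the gradient is mapped back to
        System.out.println(java.util.Arrays.toString(fwd));   // [1, 8, 8, 2]
        System.out.println(java.util.Arrays.toString(back));  // [1, 4, 4, 8], matches the input
    }
}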
Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
Class TensorMmul, method doTensorMmul:
private SDVariable doTensorMmul(SDVariable a, SDVariable b, int[][] axes) {
    // Validate that the contracted dimensions match, normalizing negative axes
    int validationLength = Math.min(axes[0].length, axes[1].length);
    for (int i = 0; i < validationLength; i++) {
        if (a.getShape()[axes[0][i]] != b.getShape()[axes[1][i]])
            throw new IllegalArgumentException("Size of the given axes at each dimension must be the same size.");
        if (axes[0][i] < 0)
            axes[0][i] += a.getShape().length;
        if (axes[1][i] < 0)
            axes[1][i] += b.getShape().length;
    }

    // Free (non-contracted) axes of a; contracted axes are permuted to the end.
    // Ints is com.google.common.primitives.Ints (Guava).
    List<Integer> listA = new ArrayList<>();
    for (int i = 0; i < a.getShape().length; i++) {
        if (!Ints.contains(axes[0], i))
            listA.add(i);
    }
    int[] newAxesA = Ints.concat(Ints.toArray(listA), axes[0]);

    // Free axes of b; contracted axes are permuted to the front
    List<Integer> listB = new ArrayList<>();
    for (int i = 0; i < b.getShape().length; i++) {
        if (!Ints.contains(axes[1], i))
            listB.add(i);
    }
    int[] newAxesB = Ints.concat(axes[1], Ints.toArray(listB));

    // n2 = product of a's contracted dimensions: the shared inner dimension of the matmul
    int n2 = 1;
    int aLength = Math.min(a.getShape().length, axes[0].length);
    for (int i = 0; i < aLength; i++) {
        n2 *= a.getShape()[axes[0][i]];
    }

    // If listA is empty the free-axis shape would be empty, so fall back to {1};
    // otherwise it is overwritten below with the sizes of a's free axes
    int[] newShapeA = { -1, n2 };
    int[] oldShapeA;
    if (listA.size() == 0) {
        oldShapeA = new int[] { 1 };
    } else {
        oldShapeA = Ints.toArray(listA);
        for (int i = 0; i < oldShapeA.length; i++)
            oldShapeA[i] = a.getShape()[oldShapeA[i]];
    }

    // n3 = product of b's contracted dimensions; must equal n2
    int n3 = 1;
    int bNax = Math.min(b.getShape().length, axes[1].length);
    for (int i = 0; i < bNax; i++) {
        n3 *= b.getShape()[axes[1][i]];
    }

    // Same fallback for b's free axes
    int[] newShapeB = { n3, -1 };
    int[] oldShapeB;
    if (listB.size() == 0) {
        oldShapeB = new int[] { 1 };
    } else {
        oldShapeB = Ints.toArray(listB);
        for (int i = 0; i < oldShapeB.length; i++)
            oldShapeB[i] = b.getShape()[oldShapeB[i]];
    }

    // Permute and flatten both tensors to 2-D, do a single matmul,
    // then reshape the result back to the concatenated free axes
    SDVariable at = f().reshape(f().permute(a, newAxesA), newShapeA);
    SDVariable bt = f().reshape(f().permute(b, newAxesB), newShapeB);
    SDVariable ret = f().mmul(at, bt);
    int[] aPlusB = Ints.concat(oldShapeA, oldShapeB);
    return f().reshape(ret, aPlusB);
}
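To make the bookkeeping concrete: for a of shape [3, 4, 5], b of shape [4, 5, 6], and axes = [[1, 2], [0, 1]], the method permutes a's contracted axes to the back and b's to the front, flattens both to 2-D, and multiplies a [3, 20] matrix by a [20, 6] matrix, giving a [3, 6] result. A standalone sketch of just that shape computation (plain Java, hypothetical class name, no Guava dependency):

import java.util.ArrayList;
import java.util.List;

public class TensorMmulShapes {
    public static void main(String[] args) {
        int[] shapeA = { 3, 4, 5 }, shapeB = { 4, 5, 6 };
        int[][] axes = { { 1, 2 }, { 0, 1 } };

        // Free (non-contracted) axes of each input
        List<Integer> listA = freeAxes(shapeA.length, axes[0]);
        List<Integer> listB = freeAxes(shapeB.length, axes[1]);

        // Products of contracted dimensions: the shared inner dimension of the matmul
        int n2 = 1; for (int ax : axes[0]) n2 *= shapeA[ax];   // 4 * 5 = 20
        int n3 = 1; for (int ax : axes[1]) n3 *= shapeB[ax];   // 4 * 5 = 20

        System.out.println("matmul: [" + product(shapeA, listA) + ", " + n2
                + "] x [" + n3 + ", " + product(shapeB, listB) + "]");  // [3, 20] x [20, 6]
    }

    // Axes of a rank-'rank' tensor not listed in 'contracted'
    static List<Integer> freeAxes(int rank, int[] contracted) {
        List<Integer> free = new ArrayList<>();
        outer:
        for (int i = 0; i < rank; i++) {
            for (int c : contracted)
                if (c == i) continue outer;
            free.add(i);
        }
        return free;
    }

    // Product of the dimension sizes at the given indices
    static int product(int[] shape, List<Integer> idx) {
        int p = 1;
        for (int i : idx) p *= shape[i];
        return p;
    }
}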