Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
The class JaccardDistance, method doDiff.
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
    // Jaccard distance: https://en.wikipedia.org/wiki/Jaccard_index#Generalized_Jaccard_similarity_and_distance
    // J(x,y) = 1 - sum_i min(x_i, y_i) / sum_i max(x_i, y_i)
    int rank = Shape.rankFromShape(larg().getShape());

    // Jaccard similarity = 1 - Jaccard distance
    SDVariable jSim = outputVariables()[0].rsub(1.0);

    SDVariable min = f().min(larg(), rarg());
    SDVariable max = f().max(larg(), rarg());
    SDVariable sumMax = f().sum(max, dimensions);
    SDVariable broadcastableSumMax = f().reductionBroadcastableWithOrigShape(rank, dimensions, sumMax);
    SDVariable broadcastableJSim = f().reductionBroadcastableWithOrigShape(rank, dimensions, jSim);

    // Indicator masks: 1 where the corresponding input supplies the element-wise min/max
    SDVariable xIsMin = f().eq(min, larg());
    SDVariable xIsMax = f().eq(max, larg());
    SDVariable yIsMin = f().eq(min, rarg());
    SDVariable yIsMax = f().eq(max, rarg());

    SDVariable dldx = xIsMax.mul(broadcastableJSim).sub(xIsMin).div(broadcastableSumMax);
    SDVariable dldy = yIsMax.mul(broadcastableJSim).sub(yIsMin).div(broadcastableSumMax);

    return Arrays.asList(dldx.mul(f1.get(0)), dldy.mul(f1.get(0)));
}
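For reference, here is a sketch of the quotient-rule derivation that the dldx line implements, writing S_min = sum_i min(x_i, y_i) and S_max = sum_i max(x_i, y_i) as shorthand (this notation does not appear in the code):

J(x,y) = 1 - \frac{S_{\min}}{S_{\max}}, \qquad
\frac{\partial J}{\partial x_i}
  = \frac{S_{\min}\,\mathbb{1}[x_i = \max(x_i,y_i)] - S_{\max}\,\mathbb{1}[x_i = \min(x_i,y_i)]}{S_{\max}^{2}}
  = \frac{(1-J)\,\mathbb{1}[x_i = \max(x_i,y_i)] - \mathbb{1}[x_i = \min(x_i,y_i)]}{S_{\max}}

Since jSim holds 1 - J, this matches xIsMax.mul(broadcastableJSim).sub(xIsMin).div(broadcastableSumMax); the final mul(f1.get(0)) applies the chain rule against the incoming gradient, and the same argument with x and y swapped gives dldy.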
Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
The class Mmul, method doDiff.
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    List<SDVariable> ret = new ArrayList<>();
    // i_v1.get(0) is the incoming gradient dL/dC for the forward op C = mmul(X, Y)
    SDVariable setup = sameDiff.setupFunction(i_v1.get(0));
    // dL/dX = dL/dC * Y^T in the untransposed case; the flipped flags adjust for the forward op's transpose configuration
    SDVariable gradWrtX = sameDiff.setupFunction(f().reshape(f().mmul(setup, rarg(),
            MMulTranspose.builder()
                    .transposeB(!mMulTranspose.isTransposeB())
                    .transposeResult(mMulTranspose.isTransposeA())
                    .build()), larg().getShape()));
    // dL/dY = X^T * dL/dC in the untransposed case
    SDVariable gradWrtY = sameDiff.setupFunction(f().reshape(f().mmul(larg(), setup,
            MMulTranspose.builder()
                    .transposeA(!mMulTranspose.isTransposeA())
                    .transposeResult(mMulTranspose.isTransposeB())
                    .build()), rarg().getShape()));
    ret.add(gradWrtX);
    ret.add(gradWrtY);
    return ret;
}
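The transpose bookkeeping above obscures a standard identity; for the plain case C = X Y (no transpose flags set in the forward op), the two gradients reduce to the usual matrix-calculus rules:

\frac{\partial L}{\partial X} = \frac{\partial L}{\partial C}\, Y^{\top},
\qquad
\frac{\partial L}{\partial Y} = X^{\top}\, \frac{\partial L}{\partial C}

The MMulTranspose builder calls generalize this to the cases where the forward multiplication already transposed X, Y, or its result.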
Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
The class NormMax, method doDiff.
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    // maxnorm(in) = max_i |x_i|
    // d maxnorm(in)/dx_i = 0 if |x_i| is not the max, or d|x|/dx = sign(x_i) otherwise
    SDVariable absIn = sameDiff.abs(arg());
    SDVariable maxnorm = outputVariables()[0];

    // TODO shape may not always be defined?
    int origRank = Shape.rankFromShape(arg().getShape());
    SDVariable maxnormBc = f().reductionBroadcastableWithOrigShape(origRank, dimensions, maxnorm);
    maxnormBc = sameDiff.onesLike(arg()).mul(maxnormBc);

    // Mask selecting the elements whose absolute value equals the max norm
    SDVariable eq = sameDiff.eq(absIn, maxnormBc);
    SDVariable dAbsXdX = sameDiff.sign(arg());
    SDVariable dNormmaxDx = eq.mul(dAbsXdX);

    // Chain rule: multiply by the incoming gradient, broadcast back to the input shape
    SDVariable broadcastableGrad = f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0));
    SDVariable ret = dNormmaxDx.mul(broadcastableGrad);
    return Arrays.asList(ret);
}
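Written out, the derivative that the equality-mask construction implements is the following (a sketch; note that when several elements tie for the maximum, each tied element receives the full gradient, since the mask is 1 for all of them):

\mathrm{maxnorm}(x) = \max_i |x_i|,
\qquad
\frac{\partial\, \mathrm{maxnorm}(x)}{\partial x_i}
  = \operatorname{sign}(x_i)\, \mathbb{1}\big[\, |x_i| = \mathrm{maxnorm}(x) \,\big]

so dL/dx_i = dL/dout * sign(x_i) for the maximal element(s) along the reduced dimensions, and 0 everywhere else.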
Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
The class Prod, method doDiff.
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    // prod(in) = prod_i x_i, so d prod(in)/dx_i = prod(in) / x_i
    SDVariable prod = outputVariables()[0];

    // TODO shape may not always be defined?
    int origRank = Shape.rankFromShape(arg().getShape());
    SDVariable broadcastableGrad = sameDiff.f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0));
    SDVariable broadcastableProd = sameDiff.f().reductionBroadcastableWithOrigShape(origRank, dimensions, prod);

    // Chain rule: dL/dx_i = dL/dprod * prod / x_i
    SDVariable mul = broadcastableGrad.div(arg());
    SDVariable ret = broadcastableProd.mul(mul);
    return Arrays.asList(ret);
}
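The division by arg() relies on the usual product-reduction shortcut, sketched below; it assumes every x_i is nonzero, since prod(x)/x_i is undefined otherwise even though the true partial derivative still exists:

\mathrm{prod}(x) = \prod_i x_i,
\qquad
\frac{\partial\, \mathrm{prod}(x)}{\partial x_i} = \prod_{j \ne i} x_j = \frac{\mathrm{prod}(x)}{x_i}

so dL/dx_i = dL/dprod * prod(x) / x_i, which is what broadcastableProd.mul(broadcastableGrad.div(arg())) computes.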
Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
The class StandardDeviation, method doDiff.
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v1) {
    // Here: calculating dL/dIn given dL/dOut (i.e., i_v1) and the input/output
    // If out = stdev(in) then:
    // dL/dIn = dL/dOut * dOut/dIn
    // dOut/dIn_i = (in_i - mean) / (stdev * (n-1))   [or divided by n if not bias corrected]
    int origRank = Shape.rankFromShape(arg().getShape());
    int n = f().getReductionLength(this);
    SDVariable broadcastableStdevOut = f().reductionBroadcastableWithOrigShape(origRank, dimensions, outputVariables()[0]);
    SDVariable broadcastableMean = f().reductionBroadcastableWithOrigShape(origRank, dimensions, f().mean(arg(), dimensions));
    SDVariable diff = arg().sub(broadcastableMean);

    SDVariable dOutdIn = diff.div(broadcastableStdevOut);
    if (this.biasCorrected) {
        dOutdIn = dOutdIn.div(n - 1);
    } else {
        dOutdIn = dOutdIn.div(n);
    }

    // Chain rule: multiply by the incoming gradient, broadcast back to the input shape
    SDVariable broadcastableGrad = f().reductionBroadcastableWithOrigShape(origRank, dimensions, i_v1.get(0));
    SDVariable dLdIn = dOutdIn.mul(broadcastableGrad);
    return Arrays.asList(dLdIn);
}
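A sketch of why the mean's own dependence on x_i drops out of dOut/dIn, writing \bar{x} for the mean and s for the bias-corrected standard deviation (neither symbol appears in the code). Because \sum_j (x_j - \bar{x}) = 0, the term from differentiating \bar{x} vanishes, leaving

s = \sqrt{\frac{1}{n-1} \sum_j (x_j - \bar{x})^{2}},
\qquad
\frac{\partial s}{\partial x_i}
  = \frac{1}{2 s (n-1)} \cdot 2 (x_i - \bar{x})
  = \frac{x_i - \bar{x}}{(n-1)\, s}

With biasCorrected == false, the 1/(n-1) factor in the variance becomes 1/n, giving the else branch's division by n.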