Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
The class GradCheckUtil, method checkGradients.
public static boolean checkGradients(SameDiff sd, double eps, double maxRelError, double minAbsError, boolean print, boolean exitOnFirstFailure) {
    // Check data type: gradient checks require double precision
    if (Nd4j.dataType() != DataBuffer.Type.DOUBLE) {
        throw new IllegalStateException("Data type must be set to double");
    }
    Set<String> fnOutputs = new HashSet<>();
    for (DifferentialFunction f : sd.functions()) {
        for (SDVariable s : f.outputVariables()) {
            fnOutputs.add(s.getVarName());
        }
    }
    // Check that all *input* SDVariables have arrays associated with them
    for (SDVariable s : sd.variables()) {
        if (fnOutputs.contains(s.getVarName())) {
            // This variable is a function output, not an input to the graph
            continue;
        }
        if (s.getArr() == null) {
            throw new IllegalStateException("Variable \"" + s.getVarName() + "\" does not have an array associated with it");
        }
    }
    // Do forward pass, check that the output is a scalar:
    INDArray out = sd.execAndEndResult();
    if (out.length() != 1) {
        throw new IllegalStateException("Output variable is not a scalar - has shape " + Arrays.toString(out.shape()));
    }
    // TODO also check that all inputs are non-zero (otherwise: consider out = sum(x * y) with all x and y being 0;
    // in this case, the gradients of x and y are all 0 too)
    sd.execBackwards();
    Map<String, INDArray> grad = new HashMap<>();
    for (SDVariable v : sd.variables()) {
        if (fnOutputs.contains(v.getVarName())) {
            // This variable is a function output, not an input to the graph
            continue;
        }
        SDVariable g = sd.grad(v.getVarName());
        if (g == null) {
            throw new IllegalStateException("Null gradient variable for \"" + v.getVarName() + "\"");
        }
        INDArray ga = g.getArr();
        if (ga == null) {
            throw new IllegalStateException("Null gradient array encountered for variable: " + v.getVarName());
        }
        if (!Arrays.equals(v.getArr().shape(), ga.shape())) {
            throw new IllegalStateException("Gradient shape does not match variable shape for variable \"" + v.getVarName()
                    + "\": shape " + Arrays.toString(v.getArr().shape()) + " vs. gradient shape " + Arrays.toString(ga.shape()));
        }
        grad.put(v.getVarName(), ga.dup());
    }
    // Validate gradients for each variable:
    int totalNFailures = 0;
    int totalCount = 0;
    double maxError = 0.0;
    for (SDVariable s : sd.variables()) {
        if (fnOutputs.contains(s.getVarName())) {
            // This variable is a function output, not an input to the graph
            continue;
        }
        String name = s.getVarName();
        INDArray a = s.getArr();
        int n = a.length();
        if (print) {
            log.info("Starting test for variable \"{}\" with {} values", s.getVarName(), n);
        }
        NdIndexIterator iter = new NdIndexIterator('c', a.shape());
        int i = 0;
        while (iter.hasNext()) {
            int[] idx = iter.next();
            String strIdx = null;
            if (print) {
                strIdx = Arrays.toString(idx).replaceAll(" ", "");
            }
            totalCount++;
            double orig = a.getDouble(idx);
            // Central finite difference: evaluate at x + eps and x - eps, then restore the original value
            a.putScalar(idx, orig + eps);
            double scorePlus = sd.execAndEndResult().getDouble(0);
            a.putScalar(idx, orig - eps);
            double scoreMinus = sd.execAndEndResult().getDouble(0);
            a.putScalar(idx, orig);
            double numericalGrad = (scorePlus - scoreMinus) / (2 * eps);
            double analyticGrad = grad.get(s.getVarName()).getDouble(idx);
            if (Double.isInfinite(numericalGrad) || Double.isNaN(numericalGrad)) {
                throw new IllegalStateException("Numerical gradient was " + numericalGrad + " for variable \"" + name
                        + "\", parameter " + i + " of " + n + " (position: " + strIdx + ")");
            }
            if (Double.isInfinite(analyticGrad) || Double.isNaN(analyticGrad)) {
                throw new IllegalStateException("Analytic (SameDiff) gradient was " + analyticGrad + " for variable \"" + name
                        + "\", parameter " + i + " of " + n + " (position: " + strIdx + ")");
            }
            double relError;
            if (numericalGrad == 0.0 && analyticGrad == 0.0) {
                relError = 0.0;
            } else {
                relError = Math.abs(analyticGrad - numericalGrad) / (Math.abs(analyticGrad) + Math.abs(numericalGrad));
            }
            if (relError > maxError)
                maxError = relError;
            if (relError > maxRelError || Double.isNaN(relError)) {
                double absError = Math.abs(analyticGrad - numericalGrad);
                if (absError < minAbsError) {
                    if (print) {
                        log.info("Param " + i + " (" + name + strIdx + ") passed: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad
                                + ", relError= " + relError + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsError);
                    }
                } else {
                    if (print)
                        log.info("Param " + i + " (" + name + strIdx + ") FAILED: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad
                                + ", relError= " + relError + ", absError=" + absError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus);
                    if (exitOnFirstFailure)
                        return false;
                    totalNFailures++;
                }
            } else if (print) {
                log.info("Param " + i + " (" + name + strIdx + ") passed: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError);
            }
            i++;
        }
    }
    if (print) {
        int nPass = totalCount - totalNFailures;
        log.info("GradCheckUtil.checkGradients(): " + totalCount + " params checked, " + nPass + " passed, " + totalNFailures
                + " failed. Largest relative error = " + maxError);
    }
    return totalNFailures == 0;
}
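The utility above is a central-difference gradient check: each input value is perturbed by ±eps, the scalar loss is re-evaluated, and (scorePlus - scoreMinus) / (2 * eps) is compared against the analytic gradient via a relative-error threshold. Below is a minimal standalone sketch of that idea in plain Java, independent of SameDiff; the function f, its hand-derived gradient, the sample values, and the tolerances are illustrative assumptions, not part of nd4j.

import java.util.Arrays;

public class CentralDifferenceSketch {

    // Example scalar function (assumption for illustration): f(x) = sum(x_i^2)
    static double f(double[] x) {
        double s = 0.0;
        for (double v : x) s += v * v;
        return s;
    }

    // Analytic gradient of the example function: df/dx_i = 2 * x_i
    static double[] analyticGrad(double[] x) {
        double[] g = new double[x.length];
        for (int i = 0; i < x.length; i++) g[i] = 2 * x[i];
        return g;
    }

    public static void main(String[] args) {
        double eps = 1e-6;
        double maxRelError = 1e-5;
        double[] x = {0.5, -1.25, 2.0};
        double[] g = analyticGrad(x);

        boolean allPassed = true;
        for (int i = 0; i < x.length; i++) {
            double orig = x[i];
            x[i] = orig + eps;
            double scorePlus = f(x);
            x[i] = orig - eps;
            double scoreMinus = f(x);
            x[i] = orig; // restore the original value, exactly as checkGradients does

            double numericalGrad = (scorePlus - scoreMinus) / (2 * eps);
            double denom = Math.abs(numericalGrad) + Math.abs(g[i]);
            double relError = denom == 0.0 ? 0.0 : Math.abs(numericalGrad - g[i]) / denom;
            if (relError > maxRelError) allPassed = false;
            System.out.printf("i=%d numerical=%.8f analytic=%.8f relError=%.2e%n", i, numericalGrad, g[i], relError);
        }
        System.out.println("Gradient check " + (allPassed ? "passed" : "failed") + " for x=" + Arrays.toString(x));
    }
}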
Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
The class Reshape, method doDiff.
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    int[] origShape = arg().getShape();
    if (origShape == null) {
        // TODO need a more robust way to do this
        throw new ND4JIllegalStateException("Cannot reshape: original array input shape is null");
    }
    // The gradient of reshape is the incoming gradient reshaped back to the argument's original shape
    SDVariable ret = f().reshape(i_v.get(0), origShape);
    return Arrays.asList(ret);
}
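Reshape changes only shape metadata, not values: element k in row-major (C-order) traversal of the output is element k of the input. That is why the backward pass is just another reshape of the upstream gradient. A small plain-Java sketch of that index bookkeeping follows; the shapes, helper name, and sample gradient values are illustrative assumptions.

public class ReshapeGradSketch {

    // Row-major (C-order) linear index for a multi-dimensional index
    static int linearIndex(int[] idx, int[] shape) {
        int lin = 0;
        for (int d = 0; d < shape.length; d++) {
            lin = lin * shape[d] + idx[d];
        }
        return lin;
    }

    public static void main(String[] args) {
        int[] origShape = {2, 3}; // shape of the reshape op's input (output was reshaped to, say, {3, 2})

        // Upstream gradient dL/d(output), stored flat in row-major order
        double[] outputGradFlat = {1, 2, 3, 4, 5, 6};

        // The gradient w.r.t. the input is the same flat buffer, reinterpreted with origShape:
        // element (i, j) of the input gradient is outputGradFlat[linearIndex({i, j}, origShape)]
        for (int i = 0; i < origShape[0]; i++) {
            for (int j = 0; j < origShape[1]; j++) {
                double g = outputGradFlat[linearIndex(new int[]{i, j}, origShape)];
                System.out.printf("dL/dInput[%d][%d] = %.1f%n", i, j, g);
            }
        }
    }
}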
Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
The class ACosh, method doDiff.
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // dacosh(x)/dx = 1/sqrt(x^2 - 1) -- note that the domain is x >= 1
    SDVariable xSqMinus1 = sameDiff.square(arg()).sub(1.0);
    SDVariable sqrt = sameDiff.sqrt(xSqMinus1);
    return Arrays.asList(i_v.get(0).div(sqrt));
}
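A quick way to sanity-check the formula used above is to compare 1/sqrt(x^2 - 1) against a central finite difference of acosh itself. The sketch below uses the logarithmic form log(x + sqrt(x^2 - 1)) because java.lang.Math has no acosh; the helper name and sample points are illustrative assumptions.

public class ACoshGradCheck {

    // acosh via its logarithmic form; valid for x >= 1
    static double acosh(double x) {
        return Math.log(x + Math.sqrt(x * x - 1.0));
    }

    public static void main(String[] args) {
        double eps = 1e-6;
        for (double x : new double[]{1.5, 2.0, 5.0}) {
            double analytic = 1.0 / Math.sqrt(x * x - 1.0);
            double numerical = (acosh(x + eps) - acosh(x - eps)) / (2 * eps);
            System.out.printf("x=%.1f analytic=%.8f numerical=%.8f%n", x, analytic, numerical);
        }
    }
}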
Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
The class ASinh, method doDiff.
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // dasinh(x)/dx = 1/sqrt(x^2 + 1)
    SDVariable xSqPlus1 = f().square(arg()).add(1.0);
    SDVariable ret = i_v.get(0).div(f().sqrt(xSqPlus1));
    return Arrays.asList(ret);
}
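The same finite-difference sanity check works here: 1/sqrt(x^2 + 1) should match a central difference of asinh, written as log(x + sqrt(x^2 + 1)) since java.lang.Math has no asinh. Helper name and sample points are again illustrative assumptions.

public class ASinhGradCheck {

    // asinh via its logarithmic form; defined for all real x
    static double asinh(double x) {
        return Math.log(x + Math.sqrt(x * x + 1.0));
    }

    public static void main(String[] args) {
        double eps = 1e-6;
        for (double x : new double[]{-2.0, 0.0, 3.0}) {
            double analytic = 1.0 / Math.sqrt(x * x + 1.0);
            double numerical = (asinh(x + eps) - asinh(x - eps)) / (2 * eps);
            System.out.printf("x=%.1f analytic=%.8f numerical=%.8f%n", x, analytic, numerical);
        }
    }
}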
Use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.
The class ATan2, method doDiff.
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // Let z = atan(r), with r = y/x
    // dz/dr = 1/(r^2 + 1), dr/dy = 1/x, dr/dx = -y/x^2
    SDVariable y = rarg();
    SDVariable x = larg();
    SDVariable r = y.div(x);
    SDVariable dOutdr = f().square(r).add(1.0).rdiv(1.0);
    SDVariable drdy = x.rdiv(1.0);
    SDVariable drdx = f().neg(y).div(f().square(x));
    SDVariable xGrad = dOutdr.mul(drdx).mul(i_v.get(0));
    SDVariable yGrad = dOutdr.mul(drdy).mul(i_v.get(0));
    return Arrays.asList(xGrad, yGrad);
}
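The chain-rule factors above simplify to the closed forms dz/dx = -y/(x^2 + y^2) and dz/dy = x/(x^2 + y^2). A short plain-Java sketch comparing those closed forms against central finite differences of Math.atan2 (the class name and sample point are illustrative assumptions):

public class ATan2GradCheck {

    public static void main(String[] args) {
        double eps = 1e-6;
        double x = 0.8, y = -1.3; // sample point, chosen arbitrarily

        // Closed forms obtained by simplifying the chain rule above
        double dzdx = -y / (x * x + y * y);
        double dzdy = x / (x * x + y * y);

        // Central finite differences of the library atan2(y, x)
        double numX = (Math.atan2(y, x + eps) - Math.atan2(y, x - eps)) / (2 * eps);
        double numY = (Math.atan2(y + eps, x) - Math.atan2(y - eps, x)) / (2 * eps);

        System.out.printf("dz/dx analytic=%.8f numerical=%.8f%n", dzdx, numX);
        System.out.printf("dz/dy analytic=%.8f numerical=%.8f%n", dzdy, numY);
    }
}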