Example 96 with SDVariable

use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

the class GradCheckUtil method checkGradients.

public static boolean checkGradients(SameDiff sd, double eps, double maxRelError, double minAbsError, boolean print, boolean exitOnFirstFailure) {
    // Check data type:
    if (Nd4j.dataType() != DataBuffer.Type.DOUBLE) {
        throw new IllegalStateException("Data type must be set to double");
    }
    Set<String> fnOutputs = new HashSet<>();
    for (DifferentialFunction f : sd.functions()) {
        for (SDVariable s : f.outputVariables()) {
            fnOutputs.add(s.getVarName());
        }
    }
    // Check that all *input* SDVariables have arrays associated with them
    for (SDVariable s : sd.variables()) {
        if (fnOutputs.contains(s.getVarName())) {
            // This is not an input to the graph
            continue;
        }
        if (s.getArr() == null) {
            throw new IllegalStateException("Variable \"" + s.getVarName() + "\" does not have array associated with it");
        }
    }
    // Do forward pass, check that output is a scalar:
    INDArray out = sd.execAndEndResult();
    if (out.length() != 1) {
        throw new IllegalStateException("Output variable is not a scalar - has shape " + Arrays.toString(out.shape()));
    }
    // TODO also check that all inputs are non-zero (otherwise: consider out = sum(x * y) with all x and y being 0;
    // in that case the gradients of x and y are all 0 too, so the check passes trivially)
    sd.execBackwards();
    Map<String, INDArray> grad = new HashMap<>();
    for (SDVariable v : sd.variables()) {
        if (fnOutputs.contains(v.getVarName())) {
            // This is not an input to the graph
            continue;
        }
        SDVariable g = sd.grad(v.getVarName());
        if (g == null) {
            throw new IllegalStateException("Null gradient variable for \"" + v.getVarName() + "\"");
        }
        INDArray ga = g.getArr();
        if (ga == null) {
            throw new IllegalStateException("Null gradient array encountered for variable: " + v.getVarName());
        }
        if (!Arrays.equals(v.getArr().shape(), g.getArr().shape())) {
            throw new IllegalStateException("Gradient shape does not match variable shape for variable \"" + v.getVarName() + "\": shape " + Arrays.toString(v.getArr().shape()) + " vs. gradient shape " + Arrays.toString(ga.shape()));
        }
        grad.put(v.getVarName(), ga.dup());
    }
    // Validate gradients for each variable:
    int totalNFailures = 0;
    int totalCount = 0;
    double maxError = 0.0;
    for (SDVariable s : sd.variables()) {
        if (fnOutputs.contains(s.getVarName())) {
            // This is not an input to the graph
            continue;
        }
        String name = s.getVarName();
        INDArray a = s.getArr();
        int n = a.length();
        if (print) {
            log.info("Starting test for variable \"{}\" with {} values", s.getVarName(), n);
        }
        NdIndexIterator iter = new NdIndexIterator('c', a.shape());
        int i = 0;
        while (iter.hasNext()) {
            int[] idx = iter.next();
            String strIdx = null;
            if (print) {
                strIdx = Arrays.toString(idx).replaceAll(" ", "");
            }
            totalCount++;
            double orig = a.getDouble(idx);
            a.putScalar(idx, orig + eps);
            double scorePlus = sd.execAndEndResult().getDouble(0);
            a.putScalar(idx, orig - eps);
            double scoreMinus = sd.execAndEndResult().getDouble(0);
            a.putScalar(idx, orig);
            double numericalGrad = (scorePlus - scoreMinus) / (2 * eps);
            double analyticGrad = grad.get(s.getVarName()).getDouble(idx);
            if (Double.isInfinite(numericalGrad) || Double.isNaN(numericalGrad)) {
                throw new IllegalStateException("Numerical gradient was " + numericalGrad + " for variable \"" + name + "\", parameter " + i + " of " + n + " (position: " + strIdx + ")");
            }
            if (Double.isInfinite(analyticGrad) || Double.isNaN(analyticGrad)) {
                throw new IllegalStateException("Analytic (SameDiff) gradient was " + analyticGrad + " for variable \"" + name + "\", parameter " + i + " of " + n + " (position: " + strIdx + ")");
            }
            double relError;
            if (numericalGrad == 0.0 && analyticGrad == 0.0) {
                relError = 0.0;
            } else {
                relError = Math.abs(analyticGrad - numericalGrad) / (Math.abs(analyticGrad) + Math.abs(numericalGrad));
            }
            if (relError > maxError)
                maxError = relError;
            if (relError > maxRelError || Double.isNaN(relError)) {
                double absError = Math.abs(analyticGrad - numericalGrad);
                if (absError < minAbsError) {
                    if (print) {
                        log.info("Param " + i + " (" + name + strIdx + ") passed: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsError);
                    }
                } else {
                    if (print)
                        log.info("Param " + i + " (" + name + strIdx + ") FAILED: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError + ", absError=" + absError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus);
                    if (exitOnFirstFailure)
                        return false;
                    totalNFailures++;
                }
            } else if (print) {
                log.info("Param " + i + " (" + name + strIdx + ") passed: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError);
            }
            i++;
        }
    }
    if (print) {
        int nPass = totalCount - totalNFailures;
        log.info("GradCheckUtil.checkGradients(): " + totalCount + " params checked, " + nPass + " passed, " + totalNFailures + " failed. Largest relative error = " + maxError);
    }
    return totalNFailures == 0;
}
Also used : NdIndexIterator(org.nd4j.linalg.api.iter.NdIndexIterator) SDVariable(org.nd4j.autodiff.samediff.SDVariable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) DifferentialFunction(org.nd4j.autodiff.functions.DifferentialFunction)
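
A minimal usage sketch (illustrative, not part of the original GradCheckUtil source): gradient-checking a tiny SameDiff graph whose end result is a scalar, as the method requires. The tanh/sum graph, the input shape and the tolerances are assumptions, and the sd.var, sd.tanh, sd.sum and Nd4j.setDataType calls assume the API of this nd4j version.

// Illustrative sketch only -- a small graph ending in a scalar loss
Nd4j.setDataType(DataBuffer.Type.DOUBLE);            // checkGradients requires DOUBLE precision
SameDiff sd = SameDiff.create();
INDArray inArr = Nd4j.rand(new int[] { 3, 4 });
SDVariable in = sd.var("in", inArr);                 // graph input with an array attached
SDVariable tanh = sd.tanh(in);
// Reduce over all dimensions so the forward pass ends in a scalar
SDVariable loss = sd.sum(tanh, Integer.MAX_VALUE);
boolean gradsOk = GradCheckUtil.checkGradients(sd, 1e-6, 1e-3, 1e-8, true, false);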

Example 97 with SDVariable

use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

the class Reshape method doDiff.

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    int[] origShape = arg().getShape();
    if (origShape == null) {
        // TODO need a more robust way to do this
        throw new ND4JIllegalStateException("Cannot reshape: original array input shape is null");
    }
    SDVariable ret = f().reshape(i_v.get(0), origShape);
    return Arrays.asList(ret);
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)
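
The backward pass above works because reshape only rearranges a view of the same elements: the gradient with respect to the input is simply the incoming gradient reshaped back to origShape. A tiny illustrative sketch (not from the original source; the shapes are made up):

INDArray dLdOut = Nd4j.linspace(1, 12, 12).reshape(3, 4);   // gradient w.r.t. the reshaped (3x4) output
INDArray dLdIn = dLdOut.reshape(2, 6);                      // same 12 values, restored to the original 2x6 input shape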

Example 98 with SDVariable

use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

the class ACosh method doDiff.

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // dacosh(x)/dx = 1/sqrt(x^2 - 1) -- note that the domain is x >= 1
    SDVariable xSqMinus1 = sameDiff.square(arg()).sub(1.0);
    SDVariable sqrt = sameDiff.sqrt(xSqMinus1);
    return Arrays.asList(i_v.get(0).div(sqrt));
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable)
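
For reference, the chain-rule step this implements (a restatement, not from the original page): with y = acosh(x), dy/dx = 1/sqrt(x^2 - 1), so dL/dx = dL/dy * 1/sqrt(x^2 - 1), which is exactly i_v.get(0).div(sqrt) above.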

Example 99 with SDVariable

use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

the class ASinh method doDiff.

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // dasinh(x)/dx = 1 / sqrt(x^2+1)
    SDVariable xSqPlus1 = f().square(arg()).add(1.0);
    SDVariable ret = i_v.get(0).div(f().sqrt(xSqPlus1));
    return Arrays.asList(ret);
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable)
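
A quick numeric sanity check of that factor (an illustrative sketch, not part of the original source), using only java.lang.Math and the identity asinh(x) = ln(x + sqrt(x^2 + 1)):

double x = 0.7;
double eps = 1e-6;
// central finite difference of asinh at x
double plus = Math.log((x + eps) + Math.sqrt((x + eps) * (x + eps) + 1.0));
double minus = Math.log((x - eps) + Math.sqrt((x - eps) * (x - eps) + 1.0));
double numericalGrad = (plus - minus) / (2 * eps);
double analyticGrad = 1.0 / Math.sqrt(x * x + 1.0);   // the 1/sqrt(x^2+1) factor used in doDiff
// numericalGrad and analyticGrad should agree to within ~1e-9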

Example 100 with SDVariable

use of org.nd4j.autodiff.samediff.SDVariable in project nd4j by deeplearning4j.

the class ATan2 method doDiff.

@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // Let z = atan2(y, x) = atan(r), with r = y/x
    // dz/dr = 1/(r^2+1), dr/dy = 1/x, dr/dx = -y/x^2
    SDVariable y = rarg();
    SDVariable x = larg();
    SDVariable r = y.div(x);
    SDVariable dOutdr = f().square(r).add(1.0).rdiv(1.0);
    SDVariable drdy = x.rdiv(1.0);
    SDVariable drdx = f().neg(y).div(f().square(x));
    SDVariable xGrad = dOutdr.mul(drdx).mul(i_v.get(0));
    SDVariable yGrad = dOutdr.mul(drdy).mul(i_v.get(0));
    return Arrays.asList(xGrad, yGrad);
}
Also used : SDVariable(org.nd4j.autodiff.samediff.SDVariable)
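
Combining the pieces above (a restatement for reference, not from the original page), with r = y/x:

dz/dy = dz/dr * dr/dy = 1/(r^2 + 1) * (1/x)      =  x / (x^2 + y^2)
dz/dx = dz/dr * dr/dx = 1/(r^2 + 1) * (-y/x^2)   = -y / (x^2 + y^2)

so yGrad and xGrad are the incoming gradient i_v.get(0) scaled by these two factors.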

Aggregations

SDVariable (org.nd4j.autodiff.samediff.SDVariable): 104
SameDiff (org.nd4j.autodiff.samediff.SameDiff): 41
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 38
Test (org.junit.Test): 36
ArrayList (java.util.ArrayList): 18
DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp): 10
lombok.val (lombok.val): 7
LossFunctions (org.nd4j.autodiff.loss.LossFunctions): 4
LossInfo (org.nd4j.autodiff.loss.LossInfo): 4
BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution): 4
Ignore (org.junit.Ignore): 3
DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction): 3
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 3
Triple (org.nd4j.linalg.primitives.Triple): 2
DataOutputStream (java.io.DataOutputStream): 1
FileOutputStream (java.io.FileOutputStream): 1
ByteBuffer (java.nio.ByteBuffer): 1
NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException): 1
NdIndexIterator (org.nd4j.linalg.api.iter.NdIndexIterator): 1
TruncateDivOp (org.nd4j.linalg.api.ops.impl.transforms.arithmetic.TruncateDivOp): 1