Example 16 with DifferentialFunction

Use of org.nd4j.autodiff.functions.DifferentialFunction in project nd4j by deeplearning4j.

The class GradCheckUtil, method checkGradients.

public static boolean checkGradients(SameDiff sd, double eps, double maxRelError, double minAbsError, boolean print, boolean exitOnFirstFailure) {
    // Check data type:
    if (Nd4j.dataType() != DataBuffer.Type.DOUBLE) {
        throw new IllegalStateException("Data type must be set to double");
    }
    Set<String> fnOutputs = new HashSet<>();
    for (DifferentialFunction f : sd.functions()) {
        for (SDVariable s : f.outputVariables()) {
            fnOutputs.add(s.getVarName());
        }
    }
    // Check that all *input* SDVariables have arrays associated with them
    for (SDVariable s : sd.variables()) {
        if (fnOutputs.contains(s.getVarName())) {
            // This is not an input to the graph
            continue;
        }
        if (s.getArr() == null) {
            throw new IllegalStateException("Variable \"" + s.getVarName() + "\" does not have an array associated with it");
        }
    }
    // Do forward pass, check that output is a scalar:
    INDArray out = sd.execAndEndResult();
    if (out.length() != 1) {
        throw new IllegalStateException("Output variable is not a scalar - has shape " + Arrays.toString(out.shape()));
    }
    // TODO: also check that all inputs are non-zero (otherwise, consider out = sum(x * y) with all x and y being 0;
    // in that case the gradients of x and y are all 0 too, and the check passes vacuously)
    sd.execBackwards();
    Map<String, INDArray> grad = new HashMap<>();
    for (SDVariable v : sd.variables()) {
        if (fnOutputs.contains(v.getVarName())) {
            // This is not an input to the graph
            continue;
        }
        SDVariable g = sd.grad(v.getVarName());
        if (g == null) {
            throw new IllegalStateException("Null gradient variable for \"" + v.getVarName() + "\"");
        }
        INDArray ga = g.getArr();
        if (ga == null) {
            throw new IllegalStateException("Null gradient array encountered for variable: " + v.getVarName());
        }
        if (!Arrays.equals(v.getArr().shape(), ga.shape())) {
            throw new IllegalStateException("Gradient shape does not match variable shape for variable \"" + v.getVarName() + "\": shape " + Arrays.toString(v.getArr().shape()) + " vs. gradient shape " + Arrays.toString(ga.shape()));
        }
        grad.put(v.getVarName(), ga.dup());
    }
    // Validate gradients for each variable:
    int totalNFailures = 0;
    int totalCount = 0;
    double maxError = 0.0;
    for (SDVariable s : sd.variables()) {
        if (fnOutputs.contains(s.getVarName())) {
            // This is not an input to the graph
            continue;
        }
        String name = s.getVarName();
        INDArray a = s.getArr();
        int n = a.length();
        if (print) {
            log.info("Starting test for variable \"{}\" with {} values", s.getVarName(), n);
        }
        NdIndexIterator iter = new NdIndexIterator('c', a.shape());
        int i = 0;
        while (iter.hasNext()) {
            int[] idx = iter.next();
            String strIdx = null;
            if (print) {
                strIdx = Arrays.toString(idx).replaceAll(" ", "");
            }
            totalCount++;
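            // Central difference: perturb this single parameter by +eps and -eps,
            // re-running the full forward pass for each perturbation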
            double orig = a.getDouble(idx);
            a.putScalar(idx, orig + eps);
            double scorePlus = sd.execAndEndResult().getDouble(0);
            a.putScalar(idx, orig - eps);
            double scoreMinus = sd.execAndEndResult().getDouble(0);
            a.putScalar(idx, orig);
            double numericalGrad = (scorePlus - scoreMinus) / (2 * eps);
            double analyticGrad = grad.get(s.getVarName()).getDouble(idx);
            if (Double.isInfinite(numericalGrad) || Double.isNaN(numericalGrad)) {
                throw new IllegalStateException("Numerical gradient was " + numericalGrad + " for variable \"" + name + "\", parameter " + i + " of " + n + " (position: " + strIdx + ")");
            }
            if (Double.isInfinite(analyticGrad) || Double.isNaN(analyticGrad)) {
                throw new IllegalStateException("Analytic (SameDiff) gradient was " + analyticGrad + " for variable \"" + name + "\", parameter " + i + " of " + n + " (position: " + strIdx + ")");
            }
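            // Relative error: |analytic - numerical| / (|analytic| + |numerical|),
            // defined as 0 when both gradients are exactly 0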
            double relError;
            if (numericalGrad == 0.0 && analyticGrad == 0.0) {
                relError = 0.0;
            } else {
                relError = Math.abs(analyticGrad - numericalGrad) / (Math.abs(analyticGrad) + Math.abs(numericalGrad));
            }
            if (relError > maxError)
                maxError = relError;
            if (relError > maxRelError || Double.isNaN(relError)) {
                double absError = Math.abs(analyticGrad - numericalGrad);
                if (absError < minAbsError) {
                    if (print) {
                        log.info("Param " + i + " (" + name + strIdx + ") passed: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsError);
                    }
                } else {
                    if (print)
                        log.info("Param " + i + " (" + name + strIdx + ") FAILED: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError + ", absError=" + absError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus);
                    if (exitOnFirstFailure)
                        return false;
                    totalNFailures++;
                }
            } else if (print) {
                log.info("Param " + i + " (" + name + strIdx + ") passed: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError);
            }
            i++;
        }
    }
    if (print) {
        int nPass = totalCount - totalNFailures;
        log.info("GradCheckUtil.checkGradients(): " + totalCount + " params checked, " + nPass + " passed, " + totalNFailures + " failed. Largest relative error = " + maxError);
    }
    return totalNFailures == 0;
}
Also used: NdIndexIterator (org.nd4j.linalg.api.iter.NdIndexIterator), SDVariable (org.nd4j.autodiff.samediff.SDVariable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction)
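
For orientation, here is a minimal driver for the checkGradients method above. This is a sketch under assumptions: the graph-building calls (SameDiff.create, var, mul, sum) and Nd4j.setDataType are recalled from the nd4j API of this period and may differ in detail (e.g. the full-array reduction may need explicit dimensions in some versions); only GradCheckUtil.checkGradients itself is taken from the source above. The loss sum(x * y) echoes the TODO comment in the method.

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.factory.Nd4j;

// import for GradCheckUtil omitted: its package varies between nd4j versions

public class GradCheckExample {
    public static void main(String[] args) {
        // checkGradients throws unless the global data type is double
        Nd4j.setDataType(DataBuffer.Type.DOUBLE);
        SameDiff sd = SameDiff.create();
        // Random (hence almost surely non-zero) inputs; per the TODO above,
        // all-zero inputs would make the check vacuous
        SDVariable x = sd.var("x", Nd4j.rand(3, 4));
        SDVariable y = sd.var("y", Nd4j.rand(3, 4));
        // Scalar loss: checkGradients requires a scalar output
        SDVariable loss = sd.sum(x.mul(y));
        // eps, maxRelError and minAbsError here are illustrative values, not project defaults
        boolean passed = GradCheckUtil.checkGradients(sd, 1e-5, 1e-5, 1e-8, true, false);
        System.out.println("Gradient check passed: " + passed);
    }
}

Note the cost: every parameter triggers two extra forward passes (the score at +eps and at -eps), so checking n parameters costs 2n forward executions on top of the single backward pass; keep test graphs small.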

Example 17 with DifferentialFunction

Use of org.nd4j.autodiff.functions.DifferentialFunction in project nd4j by deeplearning4j.

The class SameDiff, method execBackwardAndEndResult.

/**
 * Execute a backwards (gradient) pass and return the end result.
 *
 * @return the output array of the last function executed in the backward pass,
 *         or null if that function is neither an {@code Op} nor a {@code DynamicCustomOp}
 */
public INDArray execBackwardAndEndResult() {
    List<DifferentialFunction> backwards = execBackwards().getRight();
    DifferentialFunction df = backwards.get(backwards.size() - 1);
    if (df instanceof Op) {
        return ((Op) df).z();
    } else if (df instanceof DynamicCustomOp) {
        return ((DynamicCustomOp) df).getOutputArgument(0);
    } else {
        return null;
    }
}
Also used: DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction)
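
A hedged usage sketch for the method above. The graph construction is hypothetical (the same assumed builder API as in the previous sketch); the point is that execBackwardAndEndResult runs the backward pass and hands back the output array of the final backward function.

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ExecBackwardExample {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable x = sd.var("x", Nd4j.rand(2, 2));
        // Scalar loss so the backward pass is well-defined
        SDVariable loss = sd.sum(x.mul(x));
        // Runs execBackwards() internally; returns z() of the last Op executed,
        // or output 0 if the last function is a DynamicCustomOp, else null
        INDArray lastBackwardOutput = sd.execBackwardAndEndResult();
        System.out.println(lastBackwardOutput);
    }
}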

Example 18 with DifferentialFunction

Use of org.nd4j.autodiff.functions.DifferentialFunction in project nd4j by deeplearning4j.

The class SameDiff, method invokeGraphOn.

/**
 * Copy this graph's variables and functions onto the given SameDiff instance.
 *
 * @param sameDiff the target SameDiff instance to invoke (copy) this graph on
 * @return the last variable of the target graph after the copy
 */
public SDVariable invokeGraphOn(SameDiff sameDiff) {
    // map the new vertices on to the old ones
    Map<Integer, Integer> thisVertexIdToNew = new HashMap<>();
    int idx = 1;
    for (val var : variables()) {
        val clone = cloner.deepCloneDontCloneInstances(var, var.getSameDiff());
        val newVar = sameDiff.var(clone);
        if (var.getArr() != null) {
            sameDiff.associateArrayWithVariable(var.getArr(), newVar);
        }
        thisVertexIdToNew.put(idx, idx);
        clone.setSameDiff(sameDiff);
        idx++;
    }
    val newFunctions = new LinkedHashMap<String, DifferentialFunction>();
    for (DifferentialFunction function : functionInstancesById.values()) {
        if (function instanceof SDVariable) {
            continue;
        }
        DifferentialFunction clone = cloner.deepCloneDontCloneInstances(function, function.getSameDiff());
        clone.setSameDiff(sameDiff);
        clone.setOwnName(function.getOwnName());
        if (sameDiff.functionExists(function.getOwnName()))
            sameDiff.putFunctionForId(function.getOwnName(), function);
        newFunctions.put(function.getOwnName(), clone);
        val argsForFunction = function.args();
        val outputsForFunction = function.outputVariables();
        // note that these have the same variable names
        sameDiff.addArgsFor(argsForFunction, clone);
        sameDiff.addOutgoingFor(outputsForFunction, function);
        for (val arg : clone.args()) {
            arg.setSameDiff(sameDiff);
        }
        for (val output : clone.outputVariables()) {
            output.setSameDiff(sameDiff);
        }
        sameDiff.functionInstancesById.put(function.getOwnName(), function);
    }
    for (val reverseArrayEntry : reverseArrayLookup.entrySet()) {
        sameDiff.reverseArrayLookup.put(reverseArrayEntry.getKey(), sameDiff.getVariable(reverseArrayEntry.getValue().getVarName()));
    }
    return sameDiff.variables().get(sameDiff.variables().size() - 1);
}
Also used: AtomicInteger (java.util.concurrent.atomic.AtomicInteger), DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction)
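
A sketch of invoking one graph on another, again with the graph-building calls assumed from this era's API rather than taken from the source. invokeGraphOn deep-clones the source graph's variables (including their arrays) and functions onto the target SameDiff instance, reusing the same variable and function names, and returns the last variable of the target graph.

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.factory.Nd4j;

public class InvokeGraphOnExample {
    public static void main(String[] args) {
        SameDiff source = SameDiff.create();
        SDVariable in = source.var("in", Nd4j.rand(2, 3));
        SDVariable out = in.mul(in);

        // invokeGraphOn copies source's variables (with their arrays) and
        // clones of its functions onto a second SameDiff instance
        SameDiff target = SameDiff.create();
        SDVariable last = source.invokeGraphOn(target);

        // Variable names carry over, so the copies can be looked up by name
        SDVariable inCopy = target.getVariable("in");
        System.out.println(inCopy.getArr());
    }
}

Because names are preserved, invoking a graph on a non-empty target risks name collisions; the method is typically called on a fresh SameDiff instance.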

Aggregations

DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction): 18
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 9
Test (org.junit.Test): 7
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 6
ArrayList (java.util.ArrayList): 3
lombok.val (lombok.val): 3
SDVariable (org.nd4j.autodiff.samediff.SDVariable): 3
SameDiff (org.nd4j.autodiff.samediff.SameDiff): 2
NoOpNameFoundException (org.nd4j.imports.NoOpNameFoundException): 2
DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp): 2
Op (org.nd4j.linalg.api.ops.Op): 2
GradientBackwardsMarker (org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker): 2
IntArrayKeyMap (org.nd4j.linalg.collection.IntArrayKeyMap): 2
Set (java.util.Set): 1
AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 1
FlowPath (org.nd4j.autodiff.samediff.flow.FlowPath): 1
NdIndexIterator (org.nd4j.linalg.api.iter.NdIndexIterator): 1
Variance (org.nd4j.linalg.api.ops.impl.accum.Variance): 1
If (org.nd4j.linalg.api.ops.impl.controlflow.If): 1
While (org.nd4j.linalg.api.ops.impl.controlflow.While): 1