Use of org.nd4j.autodiff.functions.DifferentialFunction in project nd4j by deeplearning4j.
Class GradCheckUtil, method checkGradients:
public static boolean checkGradients(SameDiff sd, double eps, double maxRelError, double minAbsError, boolean print, boolean exitOnFirstFailure) {
    // Check data type:
    if (Nd4j.dataType() != DataBuffer.Type.DOUBLE) {
        throw new IllegalStateException("Data type must be set to double");
    }

    Set<String> fnOutputs = new HashSet<>();
    for (DifferentialFunction f : sd.functions()) {
        for (SDVariable s : f.outputVariables()) {
            fnOutputs.add(s.getVarName());
        }
    }

    // Check that all *input* SDVariables have arrays associated with them
    for (SDVariable s : sd.variables()) {
        if (fnOutputs.contains(s.getVarName())) {
            // This is not an input to the graph
            continue;
        }
        if (s.getArr() == null) {
            throw new IllegalStateException("Variable \"" + s.getVarName() + "\" does not have array associated with it");
        }
    }

    // Do forward pass, check that output is a scalar:
    INDArray out = sd.execAndEndResult();
    if (out.length() != 1) {
        throw new IllegalStateException("Output variable is not a scalar - has shape " + Arrays.toString(out.shape()));
    }

    // TODO also check that all inputs are non-zero (otherwise: consider out = sum(x * y) with all x and y being 0
    // in this case, gradients of x and y are all 0 too
    sd.execBackwards();
    Map<String, INDArray> grad = new HashMap<>();
    for (SDVariable v : sd.variables()) {
        if (fnOutputs.contains(v.getVarName())) {
            // This is not an input to the graph
            continue;
        }
        SDVariable g = sd.grad(v.getVarName());
        if (g == null) {
            throw new IllegalStateException("Null gradient variable for \"" + v.getVarName() + "\"");
        }
        INDArray ga = g.getArr();
        if (ga == null) {
            throw new IllegalStateException("Null gradient array encountered for variable: " + v.getVarName());
        }
        if (!Arrays.equals(v.getArr().shape(), g.getArr().shape())) {
            throw new IllegalStateException("Gradient shape does not match variable shape for variable \"" + v.getVarName() + "\": shape " + Arrays.toString(v.getArr().shape()) + " vs. gradient shape " + Arrays.toString(ga.shape()));
        }
        grad.put(v.getVarName(), ga.dup());
    }

    // Validate gradients for each variable:
    int totalNFailures = 0;
    int totalCount = 0;
    double maxError = 0.0;
    for (SDVariable s : sd.variables()) {
        if (fnOutputs.contains(s.getVarName())) {
            // This is not an input to the graph
            continue;
        }
        String name = s.getVarName();
        INDArray a = s.getArr();
        int n = a.length();
        if (print) {
            log.info("Starting test for variable \"{}\" with {} values", s.getVarName(), n);
        }
        NdIndexIterator iter = new NdIndexIterator('c', a.shape());
        int i = 0;
        while (iter.hasNext()) {
            int[] idx = iter.next();
            String strIdx = null;
            if (print) {
                strIdx = Arrays.toString(idx).replaceAll(" ", "");
            }
            totalCount++;
            double orig = a.getDouble(idx);
            a.putScalar(idx, orig + eps);
            double scorePlus = sd.execAndEndResult().getDouble(0);
            a.putScalar(idx, orig - eps);
            double scoreMinus = sd.execAndEndResult().getDouble(0);
            a.putScalar(idx, orig);
            double numericalGrad = (scorePlus - scoreMinus) / (2 * eps);
            double analyticGrad = grad.get(s.getVarName()).getDouble(idx);
            if (Double.isInfinite(numericalGrad) || Double.isNaN(numericalGrad)) {
                throw new IllegalStateException("Numerical gradient was " + numericalGrad + " for variable \"" + name + "\", parameter " + i + " of " + n + " (position: " + strIdx + ")");
            }
            if (Double.isInfinite(analyticGrad) || Double.isNaN(analyticGrad)) {
                throw new IllegalStateException("Analytic (SameDiff) gradient was " + analyticGrad + " for variable \"" + name + "\", parameter " + i + " of " + n + " (position: " + strIdx + ")");
            }
            double relError;
            if (numericalGrad == 0.0 && analyticGrad == 0.0) {
                relError = 0.0;
            } else {
                relError = Math.abs(analyticGrad - numericalGrad) / (Math.abs(Math.abs(analyticGrad) + Math.abs(numericalGrad)));
            }
            if (relError > maxError)
                maxError = relError;
            if (relError > maxRelError || Double.isNaN(relError)) {
                double absError = Math.abs(analyticGrad - numericalGrad);
                if (absError < minAbsError) {
                    if (print) {
                        log.info("Param " + i + " (" + name + strIdx + ") passed: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsError);
                    }
                } else {
                    if (print)
                        log.info("Param " + i + " (" + name + strIdx + ") FAILED: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError + ", absError=" + absError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus);
                    if (exitOnFirstFailure)
                        return false;
                    totalNFailures++;
                }
            } else if (print) {
                log.info("Param " + i + " (" + name + strIdx + ") passed: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError);
            }
            i++;
        }
    }

    if (print) {
        int nPass = totalCount - totalNFailures;
        log.info("GradCheckUtil.checkGradients(): " + totalCount + " params checked, " + nPass + " passed, " + totalNFailures + " failed. Largest relative error = " + maxError);
    }
    return totalNFailures == 0;
}
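As a usage sketch (not code from the project): build a small SameDiff graph whose final output is a scalar, make sure ND4J is running in double precision, and pass the graph to checkGradients. The variable names, the mmul/sum ops, and the exact way double precision is enabled below are assumptions for illustration; overloads may differ between nd4j versions.

// Hypothetical graph for illustration; names and overloads are assumptions.
Nd4j.setDataType(DataBuffer.Type.DOUBLE);      // checkGradients requires the double data type
SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", Nd4j.rand(3, 4)); // input variable with a concrete array attached
SDVariable w = sd.var("w", Nd4j.rand(4, 2));   // weight variable with a concrete array attached
SDVariable z = sd.mmul("z", in, w);
SDVariable loss = sd.sum("loss", z);           // full reduction so the graph output is a scalar
// eps = 1e-6, maxRelError = 1e-5, minAbsError = 1e-8, print = true, exitOnFirstFailure = false
boolean gradsOk = GradCheckUtil.checkGradients(sd, 1e-6, 1e-5, 1e-8, true, false);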
Use of org.nd4j.autodiff.functions.DifferentialFunction in project nd4j by deeplearning4j.
Class SameDiff, method execBackwardAndEndResult:
/**
 * Execute a backwards operation
 * and return the end result
 *
 * @return the output array of the last differential function executed in the backward pass, or null if it cannot be determined
 */
public INDArray execBackwardAndEndResult() {
    List<DifferentialFunction> backwards = execBackwards().getRight();
    DifferentialFunction df = backwards.get(backwards.size() - 1);
    if (df instanceof Op) {
        return ((Op) df).z();
    } else if (df instanceof DynamicCustomOp) {
        return ((DynamicCustomOp) df).getOutputArgument(0);
    } else {
        return null;
    }
}
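A hedged sketch of how this might be called, reusing the hypothetical graph from the earlier sketch; whether a separate forward pass is required first is an assumption here.

INDArray output = sd.execAndEndResult();                  // forward pass: scalar graph output
INDArray lastGradOutput = sd.execBackwardAndEndResult();  // backward pass: output array of the final gradient function
INDArray dLdW = sd.grad("w").getArr();                    // per-variable gradients are then available via grad(name)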
Use of org.nd4j.autodiff.functions.DifferentialFunction in project nd4j by deeplearning4j.
Class SameDiff, method invokeGraphOn:
/**
 * Copy this graph's variables and functions onto another SameDiff instance.
 *
 * @param sameDiff the SameDiff instance to copy this graph onto
 * @return the last variable of the target SameDiff instance after the copy
 */
public SDVariable invokeGraphOn(SameDiff sameDiff) {
    // map the new vertices on to the old ones
    Map<Integer, Integer> thisVertexIdToNew = new HashMap<>();
    int idx = 1;
    for (val var : variables()) {
        val clone = cloner.deepCloneDontCloneInstances(var, var.getSameDiff());
        val newVar = sameDiff.var(clone);
        if (var.getArr() != null) {
            sameDiff.associateArrayWithVariable(var.getArr(), newVar);
        }
        thisVertexIdToNew.put(idx, idx);
        clone.setSameDiff(sameDiff);
        idx++;
    }

    val newFunctions = new LinkedHashMap<String, DifferentialFunction>();
    for (DifferentialFunction function : functionInstancesById.values()) {
        if (function instanceof SDVariable) {
            continue;
        }
        DifferentialFunction clone = cloner.deepCloneDontCloneInstances(function, function.getSameDiff());
        clone.setSameDiff(sameDiff);
        clone.setOwnName(function.getOwnName());
        if (sameDiff.functionExists(function.getOwnName()))
            sameDiff.putFunctionForId(function.getOwnName(), function);
        newFunctions.put(function.getOwnName(), clone);

        val argsForFunction = function.args();
        val outputsForFunction = function.outputVariables();
        // note that these have the same variable names
        sameDiff.addArgsFor(argsForFunction, clone);
        sameDiff.addOutgoingFor(outputsForFunction, function);

        for (val arg : clone.args()) {
            arg.setSameDiff(sameDiff);
        }
        for (val output : clone.outputVariables()) {
            output.setSameDiff(sameDiff);
        }
        sameDiff.functionInstancesById.put(function.getOwnName(), function);
    }

    for (val reverseArrayEntry : reverseArrayLookup.entrySet()) {
        sameDiff.reverseArrayLookup.put(reverseArrayEntry.getKey(), sameDiff.getVariable(reverseArrayEntry.getValue().getVarName()));
    }

    return sameDiff.variables().get(sameDiff.variables().size() - 1);
}
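A sketch of what calling invokeGraphOn might look like; both SameDiff instances and the variable below are hypothetical.

SameDiff source = SameDiff.create();
SDVariable x = source.var("x", Nd4j.rand(2, 2));    // hypothetical variable with an array attached
// ... build up the rest of the source graph ...
SameDiff target = SameDiff.create();
// Copies source's variables and functions onto target; returns the last variable of target after the copy
SDVariable lastVar = source.invokeGraphOn(target);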