
Example 1 with SameDiff

Use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.

From the class LossFunctions, method nonZeroCount.

/**
 * Determine the number of weight entries that are non-zero, after broadcasting
 *
 * @param weights the weights array variable
 * @param labels  the labels variable, used only to determine the broadcast shape
 * @return a scalar variable holding the count of non-zero weight entries after broadcasting
 */
private static SDVariable nonZeroCount(SDVariable weights, SDVariable labels) {
    SameDiff sd = weights.getSameDiff();
    SDVariable present = sd.neq(weights, 0.0);
    SDVariable presentBroadcast = sd.zerosLike(labels).add(present);
    return sd.sum(presentBroadcast);
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), SameDiff (org.nd4j.autodiff.samediff.SameDiff)
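
A count like this is typically used to normalize a weighted loss, so that zero-weight (masked) entries do not dilute the average. A minimal sketch of that pattern, assuming hypothetical variable names and that a nonZeroCount-style helper is accessible (the one above is private to LossFunctions):

SameDiff sd = SameDiff.create();
int minibatch = 5;
int nOut = 4;
SDVariable labels = sd.var("labels", new int[] { minibatch, nOut });
SDVariable predictions = sd.var("predictions", new int[] { minibatch, nOut });
// Per-entry weights; entries equal to 0.0 are effectively masked out
SDVariable weights = sd.var("weights", new int[] { minibatch, nOut });
// Weighted squared error, summed, then divided by the number of entries
// that actually contributed (i.e. had a non-zero weight after broadcasting)
SDVariable err = predictions.sub(labels);
SDVariable weightedSum = sd.sum(err.mul(err).mul(weights));
SDVariable normalizedLoss = weightedSum.div(nonZeroCount(weights, labels));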

Example 2 with SameDiff

Use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.

From the class GradCheckUtil, method checkGradients.

/**
 * Numerically check the gradients of a SameDiff function using finite differences.
 *
 * @param function        the function (output variable) whose gradients should be checked
 * @param wrt             the variable to differentiate with respect to
 * @param epsilon         the perturbation size used for the finite-difference approximation
 * @param maxRelError     the maximum error allowed before a parameter counts as failed
 * @param print           whether to log pass/fail statistics
 * @param inputParameters map from input variable names to their arrays
 * @return true if all checked gradients are within maxRelError, false otherwise
 */
public static boolean checkGradients(SDVariable function, SDVariable wrt, double epsilon, double maxRelError, boolean print, Map<String, INDArray> inputParameters) {
    // Basic sanity checks on input:
    if (epsilon <= 0.0 || epsilon > 0.1)
        throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
    if (maxRelError <= 0.0 || maxRelError > 0.25)
        throw new IllegalArgumentException("Invalid maxRelativeError: " + maxRelError);
    DataBuffer.Type dataType = DataTypeUtil.getDtypeFromContext();
    if (dataType != DataBuffer.Type.DOUBLE) {
        throw new IllegalStateException("Cannot perform gradient check: Datatype is not set to double precision (" + "is: " + dataType + "). Double precision must be used for gradient checks. Set " + "DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE); before using GradientCheckUtil");
    }
    /*
     * Ideally this check compares against the exact gradient, obtained by
     * executing a subgraph containing just the gradient ops; those exact
     * values are then compared against the finite-difference approximation
     * computed below.
     */
    SameDiff sameDiff = function.getSameDiff();
    // work on a copy of the graph so evaluation does not mutate the original
    SameDiff opExec = SameDiff.create(sameDiff);
    INDArray[] eval = opExec.eval(inputParameters);
    int totalNFailures = 0;
    double maxError = 0.0;
    int totalParamsChecked = 0;
    for (Map.Entry<String, INDArray> entry : inputParameters.entrySet()) {
        int nParams = entry.getValue().length();
        totalParamsChecked += nParams;
        INDArray params = entry.getValue().dup();
        for (int i = 0; i < nParams; i++) {
            double origValue = params.getDouble(i);
            Map<String, INDArray> evalParams = new HashMap<>();
            for (Map.Entry<String, INDArray> entry2 : inputParameters.entrySet()) {
                if (!entry2.getKey().equals(entry.getKey())) {
                    evalParams.put(entry2.getKey(), entry2.getValue());
                } else {
                    evalParams.put(entry.getKey(), params);
                }
            }
            // (w + epsilon): forward pass with this parameter nudged up
            params.putScalar(i, origValue + epsilon);
            INDArray[] plusParams = sameDiff.eval(evalParams);
            // (w - epsilon): forward pass with this parameter nudged down
            params.putScalar(i, origValue - epsilon);
            INDArray[] minusParams = sameDiff.eval(evalParams);
            // restore the original value before checking the next parameter
            params.putScalar(i, origValue);
            // Element-wise central-difference gradients for each output
            // (computed but not yet used in the scalar check below)
            INDArray[] newDifferences = new INDArray[minusParams.length];
            for (int j = 0; j < newDifferences.length; j++) {
                newDifferences[j] = plusParams[j].sub(minusParams[j]).divi(2.0 * epsilon);
            }
            // Central-difference approximation of the (summed) output:
            // (f(w+eps) - f(w-eps)) / (2*eps)
            double diff = plusParams[plusParams.length - 1].sumNumber().doubleValue() - minusParams[minusParams.length - 1].sumNumber().doubleValue();
            double numericalGrad = diff / (2.0 * epsilon);
            /*
             * TODO: this should be the exact gradient for parameter i; for now it
             * is the (summed) unperturbed evaluation, per the note above.
             */
            double correctVal = eval[eval.length - 1].sumNumber().doubleValue();
            double gradDiff = Math.abs(correctVal - numericalGrad);
            maxError = Math.max(maxError, gradDiff);
            if (gradDiff > maxRelError)
                totalNFailures++;
        }
    }
    if (print) {
        int nPass = totalParamsChecked - totalNFailures;
        log.info("GradientCheckUtil.checkGradients(): " + totalParamsChecked + " params checked, " + nPass + " passed, " + totalNFailures + " failed. Largest error = " + maxError);
    }
    return totalNFailures == 0;
}
Also used: SameDiff (org.nd4j.autodiff.samediff.SameDiff), INDArray (org.nd4j.linalg.api.ndarray.INDArray), DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)
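
The core of the routine above is the central-difference rule grad_i ≈ (f(w + eps) - f(w - eps)) / (2 * eps), evaluated once per parameter. A self-contained plain-Java sketch of just that idea, independent of SameDiff (all names here are illustrative, not part of the nd4j API):

import java.util.function.Function;

public class CentralDifferenceDemo {

    // Numerically approximate the gradient of f at w using central differences
    static double[] numericalGradient(Function<double[], Double> f, double[] w, double eps) {
        double[] grad = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            double orig = w[i];
            w[i] = orig + eps;
            double plus = f.apply(w);
            w[i] = orig - eps;
            double minus = f.apply(w);
            w[i] = orig; // restore before moving to the next parameter
            grad[i] = (plus - minus) / (2.0 * eps);
        }
        return grad;
    }

    public static void main(String[] args) {
        // f(w) = w0^2 + 3*w1; exact gradient at (2, 5) is (4, 3)
        Function<double[], Double> f = w -> w[0] * w[0] + 3.0 * w[1];
        double[] grad = numericalGradient(f, new double[] { 2.0, 5.0 }, 1e-4);
        System.out.println(grad[0] + ", " + grad[1]); // ~4.0, ~3.0
    }
}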

Example 3 with SameDiff

Use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.

From the class GradCheckTransforms, method testDynamicStitch.

@Test
public void testDynamicStitch() {
    SameDiff sd = SameDiff.create();
    INDArray ia = Nd4j.create(new float[] { 5, 1, 3 }, new int[] { 1, 3 });
    INDArray ib = Nd4j.create(new float[] { 7, 2, 4 }, new int[] { 1, 3 });
    INDArray indexA = Nd4j.create(new float[] { 0, 1, 4 }, new int[] { 1, 3 });
    INDArray indexB = Nd4j.create(new float[] { 2, 3, 5 }, new int[] { 1, 3 });
    INDArray expOut = Nd4j.create(new int[] { 1, 6 });
    DynamicCustomOp dynamicStitch = DynamicCustomOp.builder("dynamic_stitch").addInputs(indexA, indexB, ia, ib).addOutputs(expOut).build();
    Nd4j.getExecutioner().exec(dynamicStitch);
    SDVariable in1 = sd.var("in1", new int[] { 1, 3 });
    SDVariable in2 = sd.var("in2", new int[] { 1, 3 });
    SDVariable index1 = sd.var("index1", new int[] { 1, 3 });
    SDVariable index2 = sd.var("index2", new int[] { 1, 3 });
    sd.associateArrayWithVariable(ia, in1);
    sd.associateArrayWithVariable(ib, in2);
    sd.associateArrayWithVariable(indexA, index1);
    sd.associateArrayWithVariable(indexB, index2);
    SDVariable t = sd.dynamicStitch(new SDVariable[] { index1, index2 }, new SDVariable[] { in1, in2 });
    SDVariable loss = sd.mean("loss", t);
    sd.exec();
    INDArray out = t.getArr();
    assertEquals("forward failed", expOut, out);
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), SameDiff (org.nd4j.autodiff.samediff.SameDiff), DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp), Test (org.junit.Test)
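
For reference, dynamic_stitch scatters the input values into a single output according to the index arrays: output[indices[k][i]] = inputs[k][i]. A plain-Java sketch of that semantics (illustrative only; this is not the nd4j implementation):

// Sketch of dynamic_stitch semantics for flat inputs:
// output[indices[k][i]] = inputs[k][i]
static float[] dynamicStitch(int[][] indices, float[][] inputs) {
    int size = 0;
    for (int[] idx : indices)
        for (int i : idx)
            size = Math.max(size, i + 1);
    float[] out = new float[size];
    for (int k = 0; k < inputs.length; k++)
        for (int i = 0; i < inputs[k].length; i++)
            out[indices[k][i]] = inputs[k][i];
    return out;
}

// dynamicStitch(new int[][] { { 0, 1, 4 }, { 2, 3, 5 } }, new float[][] { { 5, 1, 3 }, { 7, 2, 4 } })
// -> [5, 1, 7, 2, 3, 4], matching expOut in the test above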

Example 4 with SameDiff

Use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.

From the class GradCheckTransforms, method testDynamicPartition.

@Test
public void testDynamicPartition() {
    SameDiff sd = SameDiff.create();
    INDArray ia = Nd4j.create(new float[] { 4, 3, 5, 7, 8, 0 }, new int[] { 1, 6 });
    INDArray partitions = Nd4j.create(new float[] { 1, 0, 1, 0, 0, 1 });
    int numPartitions = 2;
    SDVariable in = sd.var("in", new int[] { 1, 6 });
    SDVariable sdPartitions = sd.var("partitions", new int[] { 1, 6 });
    INDArray expOut1 = Nd4j.create(new int[] { 1, 3 });
    INDArray expOut2 = Nd4j.create(new int[] { 1, 3 });
    INDArray[] expOut = new INDArray[] { expOut1, expOut2 };
    DynamicCustomOp dynamicPartition = DynamicCustomOp.builder("dynamic_partition").addInputs(ia, partitions).addIntegerArguments(numPartitions).addOutputs(expOut1, expOut2).build();
    Nd4j.getExecutioner().exec(dynamicPartition);
    SDVariable[] parts = sd.dynamicPartition(in, sdPartitions, numPartitions);
    // merge the output partitions together again, to retrieve a single
    // tensor and finally a scalar.
    SDVariable t = sd.mergeAdd(parts);
    SDVariable loss = sd.mean("loss", t);
    sd.associateArrayWithVariable(ia, in);
    sd.associateArrayWithVariable(partitions, sdPartitions);
    sd.exec();
    INDArray[] out = new INDArray[numPartitions];
    for (int i = 0; i < parts.length; i++) {
        out[i] = parts[i].getArr();
    }
    // Note: equals() on a Java array is reference equality, so compare element-wise
    assertArrayEquals("forward failed", expOut, out);
}
Also used: SDVariable (org.nd4j.autodiff.samediff.SDVariable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), SameDiff (org.nd4j.autodiff.samediff.SameDiff), DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp), Test (org.junit.Test)
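
dynamic_partition is roughly the inverse operation: it splits the input into numPartitions outputs, routing element i to the partition given by partitions[i]. A plain-Java sketch of the semantics (again illustrative, not the nd4j implementation):

import java.util.ArrayList;
import java.util.List;

// Element i of data goes to output list number partitions[i]
static List<List<Float>> dynamicPartition(float[] data, int[] partitions, int numPartitions) {
    List<List<Float>> out = new ArrayList<>();
    for (int p = 0; p < numPartitions; p++)
        out.add(new ArrayList<Float>());
    for (int i = 0; i < data.length; i++)
        out.get(partitions[i]).add(data[i]);
    return out;
}

// dynamicPartition(new float[] { 4, 3, 5, 7, 8, 0 }, new int[] { 1, 0, 1, 0, 0, 1 }, 2)
// -> [[3.0, 7.0, 8.0], [4.0, 5.0, 0.0]], matching the two expected partitions in the test above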

Example 5 with SameDiff

Use of org.nd4j.autodiff.samediff.SameDiff in project nd4j by deeplearning4j.

From the class GradCheckTransforms, method testTransforms.

@Test
public void testTransforms() {
    // Test transforms (non-pairwise)
    Nd4j.getRandom().setSeed(12345);
    // TODO: to speed up integration of new ops for TF import, we sometimes skip "doDiff" implementations.
    // get back to those after release
    List<String> allSkipped = new ArrayList<>();
    List<String> allFailed = new ArrayList<>();
    for (int i = 0; i < 62; i++) {
        boolean skipBackward = false;
        SameDiff sd = SameDiff.create();
        int nOut = 4;
        int minibatch = 5;
        SDVariable in = sd.var("in", new int[] { -1, nOut });
        INDArray ia = Nd4j.randn(minibatch, nOut);
        int dim;
        SDVariable t;
        INDArray expOut;
        switch(i) {
            case 0:
                t = in.add(5.0);
                expOut = ia.add(5.0);
                break;
            case 1:
                t = in.sub(5.0);
                expOut = ia.sub(5.0);
                break;
            case 2:
                t = in.mul(2.5);
                expOut = ia.mul(2.5);
                break;
            case 3:
                t = in.div(4.0);
                expOut = ia.div(4.0);
                break;
            case 4:
                t = in.rsub(5.0);
                expOut = ia.rsub(5.0);
                break;
            case 5:
                t = in.rdiv(1.0);
                expOut = ia.rdiv(1.0);
                break;
            case 6:
                t = sd.pow(in, 2.5);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.pow(ia, 2.5, true);
                break;
            case 7:
                t = sd.sigmoid(in);
                ia = Nd4j.rand(minibatch, nOut).muli(2).subi(1.0);
                expOut = Transforms.sigmoid(ia, true);
                break;
            case 8:
                t = sd.tanh(in);
                ia = Nd4j.rand(minibatch, nOut).muli(2).subi(1.0);
                expOut = Transforms.tanh(ia, true);
                break;
            case 9:
                t = sd.tan(in);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Nd4j.getExecutioner().execAndReturn(new Tan(ia.dup()));
                break;
            case 10:
                t = sd.cos(in);
                expOut = Transforms.cos(ia, true);
                break;
            case 11:
                t = sd.sin(in);
                expOut = Transforms.sin(ia, true);
                break;
            case 12:
                t = sd.softplus(in);
                expOut = Transforms.softPlus(ia, true);
                break;
            case 13:
                t = sd.log(in);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.log(ia, true);
                skipBackward = true;
                break;
            case 14:
                t = sd.neg(in);
                expOut = ia.neg();
                break;
            case 15:
                t = sd.acos(in);
                ia = Nd4j.rand(minibatch, nOut).muli(1.8).subi(0.9);
                expOut = Transforms.acos(ia, true);
                break;
            case 16:
                t = sd.acosh(in);
                // Only defined for x >= 1
                ia = Nd4j.rand(minibatch, nOut).addi(1.01);
                expOut = Nd4j.getExecutioner().execAndReturn(new ACosh(ia.dup()));
                skipBackward = true;
                break;
            case 17:
                t = sd.asin(in);
                ia = Nd4j.rand(minibatch, nOut).muli(1.8).subi(0.9);
                expOut = Transforms.asin(ia, true);
                break;
            case 18:
                t = sd.atan(in);
                ia = Nd4j.rand(minibatch, nOut).muli(4).subi(2);
                expOut = Transforms.atan(ia, true);
                break;
            case 19:
                t = sd.atanh(in);
                ia = Nd4j.rand(minibatch, nOut).muli(1.8).subi(0.9);
                expOut = Transforms.atanh(ia, true);
                break;
            case 20:
                t = sd.cosh(in);
                expOut = Transforms.cosh(ia, true);
                break;
            case 21:
                t = sd.cube(in);
                expOut = Transforms.pow(ia, 3.0, true);
                break;
            case 22:
                t = sd.elu(in);
                expOut = Transforms.elu(ia, true);
                break;
            case 23:
                // TODO SHOULDN'T THIS HAVE A DIMENSION ARG???
                t = sd.softmax(in);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Nd4j.getExecutioner().execAndReturn(new OldSoftMax(ia.dup()));
                break;
            case 24:
                t = sd.sqrt(in);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.sqrt(ia, true);
                break;
            case 25:
                t = sd.square(in);
                expOut = Transforms.pow(ia, 2.0, true);
                break;
            case 26:
                t = sd.transpose(in);
                expOut = ia.transpose().dup();
                break;
            case 27:
                t = sd.abs(in);
                expOut = Transforms.abs(ia, true);
                break;
            case 28:
                t = sd.sinh(in);
                expOut = Transforms.sinh(ia, true);
                break;
            case 29:
                t = sd.asinh(in);
                expOut = Nd4j.getExecutioner().execAndReturn(new ASinh(ia.dup()));
                skipBackward = true;
                break;
            case 30:
                t = sd.exp(in);
                expOut = Transforms.exp(ia, true);
                break;
            case 31:
                t = sd.floor(in);
                expOut = Transforms.floor(ia, true);
                break;
            case 32:
                t = sd.relu(in, 0.0);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.relu(ia, true);
                break;
            case 33:
                t = sd.hardTanh(in);
                ia = Nd4j.rand(minibatch, nOut).muli(2).subi(1.0);
                expOut = Transforms.hardTanh(ia, true);
                break;
            case 34:
                t = sd.logSigmoid(in);
                expOut = Nd4j.getExecutioner().execAndReturn(new LogSigmoid(ia.dup()));
                break;
            case 35:
                t = sd.swish(in);
                expOut = Nd4j.getExecutioner().execAndReturn(new Swish(ia.dup()));
                break;
            case 36:
                t = sd.sign(in);
                expOut = Transforms.sign(ia, true);
                break;
            case 37:
                t = sd.softsign(in);
                expOut = Transforms.softsign(ia, true);
                break;
            case 38:
                t = sd.leakyRelu(in, 0.0);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.leakyRelu(ia, true);
                break;
            case 39:
                // TODO DIMENSION ARG???
                // TODO fix me
                t = sd.logSoftmax(in);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.log(Transforms.softmax(ia, true));
                skipBackward = true;
                break;
            case 40:
                t = sd.selu(in);
                expOut = Nd4j.getExecutioner().execAndReturn(new SELU(ia.dup()));
                break;
            case 41:
                t = sd.gt(in, 1.0);
                expOut = ia.gt(1.0);
                break;
            case 42:
                t = sd.gte(in, 1.0);
                expOut = ia.gte(1.0);
                break;
            case 43:
                t = sd.lt(in, 1.0);
                expOut = ia.lt(1.0);
                break;
            case 44:
                t = sd.lte(in, 1.0);
                expOut = ia.lte(1.0);
                break;
            case 45:
                t = sd.eq(in, 2.0);
                ia = Nd4j.linspace(1, minibatch * nOut, minibatch * nOut).reshape('c', minibatch, nOut);
                expOut = ia.eq(2.0);
                break;
            case 46:
                t = sd.neq(in, 2.0);
                ia = Nd4j.linspace(1, minibatch * nOut, minibatch * nOut).reshape('c', minibatch, nOut);
                expOut = ia.neq(2.0);
                break;
            case 47:
                t = sd.ceil(in);
                expOut = Transforms.ceil(ia, true);
                break;
            case 48:
                ia = Nd4j.randn(ia.shape()).muli(2);
                t = sd.clipByValue(in, -3, 2);
                expOut = ia.dup();
                BooleanIndexing.replaceWhere(expOut, -3, Conditions.lessThan(-3));
                BooleanIndexing.replaceWhere(expOut, 2, Conditions.greaterThan(2));
                break;
            case 49:
                // Clip by norm, dimension 0: some columns below threshold, some above
                double clip = 2.0;
                ia = Nd4j.rand(ia.shape());
                // Scale so that every column norm is exactly at the threshold...
                ia.muliRowVector(ia.norm2(0).rdiv(clip));
                // ...then spread the column norms across a range straddling it
                ia.muliRowVector(Nd4j.linspace(0.9, 1.1, ia.size(1)));
                expOut = Nd4j.create(ia.shape());
                for (int j = 0; j < ia.columns(); j++) {
                    INDArray origCol = ia.getColumn(j);
                    if (origCol.norm2Number().doubleValue() < clip) {
                        expOut.putColumn(j, origCol);
                    } else {
                        expOut.putColumn(j, origCol.mul(clip / origCol.norm2Number().doubleValue()));
                    }
                }
                t = sd.clipByNorm(in, clip);
                break;
            // TODO clip by norm along other dimensions
            case 50:
                dim = 1;
                t = sd.reverse(in, dim);
                expOut = Nd4j.create(ia.shape());
                DynamicCustomOp reverse = DynamicCustomOp.builder("reverse").addIntegerArguments(dim).addInputs(ia).addOutputs(expOut).build();
                Nd4j.getExecutioner().exec(reverse);
                break;
            case 51:
                dim = 0;
                boolean exclusive = false;
                boolean reverseBool = false;
                t = sd.cumsum(in, exclusive, reverseBool, dim);
                expOut = Nd4j.create(ia.shape());
                DynamicCustomOp cumsum = DynamicCustomOp.builder("cumsum").addIntegerArguments((exclusive) ? 1 : 0, (reverseBool) ? 1 : 0, dim).addInputs(ia).addOutputs(expOut).build();
                Nd4j.getExecutioner().exec(cumsum);
                break;
            case 52:
                dim = 0;
                boolean ex = false;
                boolean revBool = false;
                // The expected output below is cumprod, so build the cumprod op here
                t = sd.cumprod(in, ex, revBool, dim);
                expOut = Nd4j.create(ia.shape());
                DynamicCustomOp cumprod = DynamicCustomOp.builder("cumprod").addIntegerArguments((ex) ? 1 : 0, (revBool) ? 1 : 0, dim).addInputs(ia).addOutputs(expOut).build();
                Nd4j.getExecutioner().exec(cumprod);
                break;
            case 53:
                ia = Nd4j.create(new float[] { 4, 2 });
                in = sd.var("in", new int[] { 1, 2 });
                expOut = Nd4j.create(new int[] { 2, 2 });
                DynamicCustomOp op = DynamicCustomOp.builder("diag").addInputs(ia).addOutputs(expOut).build();
                Nd4j.getExecutioner().exec(op);
                t = sd.diag(in);
                break;
            case 54:
                expOut = Nd4j.createUninitialized(ia.shape(), ia.ordering());
                Nd4j.getExecutioner().exec(new Erf(ia, expOut));
                t = sd.erf(in);
                skipBackward = true;
                break;
            case 55:
                expOut = Nd4j.createUninitialized(ia.shape(), ia.ordering());
                Nd4j.getExecutioner().exec(new Erfc(ia, expOut));
                t = sd.erfc(in);
                skipBackward = true;
                break;
            case 56:
                t = sd.expm1(in);
                expOut = Transforms.expm1(ia, true);
                break;
            case 57:
                t = sd.log1p(in);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.log1p(ia, true);
                skipBackward = true;
                break;
            case 58:
                t = sd.round(in);
                expOut = Transforms.round(ia, true);
                skipBackward = true;
                break;
            case 59:
                ia = Nd4j.create(new float[] { 4, 2 });
                in = sd.var("in", new int[] { 1, 2 });
                t = sd.rsqrt(in);
                expOut = Nd4j.create(ia.shape(), ia.ordering());
                Nd4j.getExecutioner().exec(new RSqrt(ia, expOut));
                skipBackward = true;
                break;
            case 60:
                t = sd.relu6(in, 0);
                ia = Nd4j.rand(minibatch, nOut);
                expOut = Transforms.relu6(ia);
                skipBackward = true;
                break;
            case 61:
                ia = Nd4j.create(new float[] { 2, 2 });
                in = sd.var("in", new int[] { 1, 2 });
                sd.associateArrayWithVariable(ia, in);
                double value = 42;
                expOut = Nd4j.create(new int[] { 2, 2 });
                DynamicCustomOp fillOp = DynamicCustomOp.builder("fill").addInputs(ia).addFloatingPointArguments(value).addOutputs(expOut).build();
                Nd4j.getExecutioner().exec(fillOp);
                skipBackward = true;
                t = sd.fill(in, value);
                break;
            default:
                throw new RuntimeException();
        }
        DifferentialFunction[] funcs = sd.functions();
        String name = funcs[0].opName();
        String msg = "test: " + i + " - " + name;
        log.info("*** Starting test: " + msg);
        SDVariable loss = sd.mean("loss", t);
        sd.associateArrayWithVariable(ia, in);
        sd.exec();
        INDArray out = t.getArr();
        if (!expOut.equals(out)) {
            allFailed.add(msg + " - FAILED ON FORWARD");
            continue;
        }
        boolean ok;
        if (skipBackward) {
            ok = true;
            msg += " - SKIPPED";
            allSkipped.add(msg);
        } else {
            try {
                ok = GradCheckUtil.checkGradients(sd);
            } catch (Exception e) {
                e.printStackTrace();
                msg += " - EXCEPTION";
                ok = false;
            }
        }
        if (!ok) {
            allFailed.add(msg);
        }
    }
    if (allSkipped.size() > 0) {
        log.info("All backward skipped transforms: " + allSkipped);
        log.info(allSkipped.size() + " backward passes were skipped.");
    }
    if (allFailed.size() > 0) {
        log.error("All failed transforms: " + allFailed);
        fail(allFailed.size() + " transforms failed");
    }
}
Also used: SameDiff (org.nd4j.autodiff.samediff.SameDiff), ArrayList (java.util.ArrayList), SDVariable (org.nd4j.autodiff.samediff.SDVariable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction), DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp), Test (org.junit.Test)
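
The expected output built by hand in case 49 follows the clip-by-norm rule, per column c: keep x_c unchanged if ||x_c||_2 <= clip, otherwise scale it by clip / ||x_c||_2. A tiny plain-Java sketch of that rule for a single column (illustrative only; the test above uses the nd4j op):

// Clip a single column (vector) to a maximum L2 norm
static double[] clipByNorm(double[] col, double clip) {
    double sumSq = 0;
    for (double v : col)
        sumSq += v * v;
    double norm = Math.sqrt(sumSq);
    if (norm <= clip)
        return col.clone(); // already within the threshold
    double[] out = new double[col.length];
    for (int i = 0; i < col.length; i++)
        out[i] = col[i] * clip / norm; // rescale so the new norm equals clip
    return out;
}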

Aggregations

SameDiff (org.nd4j.autodiff.samediff.SameDiff): 50
Test (org.junit.Test): 42
SDVariable (org.nd4j.autodiff.samediff.SDVariable): 41
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 37
ArrayList (java.util.ArrayList): 10
DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp): 10
Ignore (org.junit.Ignore): 7
ClassPathResource (org.nd4j.linalg.io.ClassPathResource): 6
lombok.val (lombok.val): 4
LossFunctions (org.nd4j.autodiff.loss.LossFunctions): 4
LossInfo (org.nd4j.autodiff.loss.LossInfo): 4
BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution): 4
DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction): 2
Triple (org.nd4j.linalg.primitives.Triple): 2
ZeroInitScheme (org.nd4j.weightinit.impl.ZeroInitScheme): 2
DataOutputStream (java.io.DataOutputStream): 1
FileOutputStream (java.io.FileOutputStream): 1
ByteBuffer (java.nio.ByteBuffer): 1
HashMap (java.util.HashMap): 1
LinkedHashMap (java.util.LinkedHashMap): 1