
Example 6 with BernoulliDistribution

Use of org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution in project nd4j by deeplearning4j.

The class GradCheckLoss, method testLossWeights4d:

@Test
public void testLossWeights4d() {
    String[] weightTypes = new String[] { "none", "per-example", "per-depth", "per-height", "per-width",
            "per-height-width", "per-depth-height", "per-depth-width", "per-example-depth",
            "per-example-height", "per-example-height-width", "per-all" };
    Nd4j.getRandom().setSeed(12345);
    // Assume NCHW format here
    int minibatch = 10;
    int depth = 4;
    int h = 5;
    int w = 6;
    for (String weightType : weightTypes) {
        for (boolean binary : new boolean[] { true, false }) {
            // Binary mask (like DL4J) or arbitrary weights?
            int[] weightShape;
            switch(weightType) {
                case "none":
                    weightShape = null;
                    break;
                case "per-example":
                    weightShape = new int[] { minibatch, 1, 1, 1 };
                    break;
                case "per-depth":
                    weightShape = new int[] { 1, depth, 1, 1 };
                    break;
                case "per-height":
                    weightShape = new int[] { 1, 1, h, 1 };
                    break;
                case "per-width":
                    weightShape = new int[] { 1, 1, 1, w };
                    break;
                case "per-height-width":
                    weightShape = new int[] { 1, 1, h, w };
                    break;
                case "per-depth-height":
                    weightShape = new int[] { 1, depth, h, 1 };
                    break;
                case "per-depth-width":
                    weightShape = new int[] { 1, depth, 1, w };
                    break;
                case "per-example-depth":
                    weightShape = new int[] { minibatch, depth, 1, 1 };
                    break;
                case "per-example-height":
                    weightShape = new int[] { minibatch, 1, h, 1 };
                    break;
                case "per-example-height-width":
                    weightShape = new int[] { minibatch, 1, h, w };
                    break;
                case "per-all":
                    weightShape = new int[] { minibatch, depth, h, w };
                    break;
                default:
                    throw new RuntimeException("Unknown type: " + weightType);
            }
            INDArray weightArr = null;
            if (!"none".equals(weightType)) {
                if (binary) {
                    weightArr = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(weightShape), 0.5));
                } else {
                    weightArr = Nd4j.rand(weightShape).muli(2.0);
                }
            }
            for (LossFunctions.Reduction reduction : new LossFunctions.Reduction[] {
                    LossFunctions.Reduction.MEAN_BY_COUNT, LossFunctions.Reduction.MEAN_BY_WEIGHT,
                    LossFunctions.Reduction.SUM }) {
                for (String fn : new String[] { "mse", "l1", "l2", "mcxent" }) {
                    SameDiff sd = SameDiff.create();
                    SDVariable input = sd.var("in", new int[] { -1, depth, -1, -1 });
                    SDVariable labels = sd.var("labels", new int[] { -1, depth, -1, -1 });
                    SDVariable weight = null;
                    if (!"none".equals(weightType)) {
                        weight = sd.var("weights", weightArr);
                    }
                    INDArray inputArr = Nd4j.randn(new int[] { minibatch, depth, h, w }).muli(10);
                    INDArray labelsArr = Nd4j.randn(new int[] { minibatch, depth, h, w }).muli(10);
                    LossInfo lossInfo;
                    switch(fn) {
                        case "mse":
                            lossInfo = LossFunctions.mse("out", input, labels, weight, reduction, 1, 2, 3);
                            break;
                        case "l1":
                            lossInfo = LossFunctions.l1("out", input, labels, weight, reduction, 1, 2, 3);
                            // L1 = sum abs error
                            break;
                        case "l2":
                            lossInfo = LossFunctions.l2("out", input, labels, weight, reduction, 1, 2, 3);
                            // L2 = sum squared error
                            break;
                        case "mcxent":
                            lossInfo = LossFunctions.mcxent("out", input, labels, weight, reduction, 1, 2, 3);
                            // mcxent = -(sum of label * log(prob))
                            break;
                        default:
                            throw new RuntimeException();
                    }
                    String msg = "lossFn=" + fn + ", reduction=" + reduction + ", weightType=" + weightType + ", binaryWeight=" + binary;
                    log.info("*** Starting test: " + msg);
                    sd.associateArrayWithVariable(inputArr, input);
                    sd.associateArrayWithVariable(labelsArr, labels);
                    if (weight != null) {
                        sd.associateArrayWithVariable(weightArr, weight);
                    }
                    INDArray out = sd.execAndEndResult();
                    assertEquals(1, out.length());
                    boolean ok = GradCheckUtil.checkGradients(sd);
                    assertTrue(msg, ok);
                }
            }
        }
    }
}
Also used: SameDiff (org.nd4j.autodiff.samediff.SameDiff), LossFunctions (org.nd4j.autodiff.loss.LossFunctions), LossInfo (org.nd4j.autodiff.loss.LossInfo), SDVariable (org.nd4j.autodiff.samediff.SDVariable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution), Test (org.junit.Test)
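
The binary-weight branch above uses a common nd4j idiom: executing a BernoulliDistribution op over a freshly allocated array to obtain a {0, 1} mask. A minimal, self-contained sketch of that pattern (the class name, shape, and probability below are illustrative, not from the test):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;

public class BernoulliMaskSketch {
    public static void main(String[] args) {
        Nd4j.getRandom().setSeed(12345);
        // Fill a fresh 3x4 array with independent Bernoulli(0.5) draws;
        // exec(...) runs the random op in place and returns the filled array.
        INDArray mask = Nd4j.getExecutioner()
                .exec(new BernoulliDistribution(Nd4j.createUninitialized(new int[] { 3, 4 }), 0.5));
        System.out.println(mask); // each entry is 0.0 or 1.0
    }
}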

Example 7 with BernoulliDistribution

Use of org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution in project deeplearning4j by deeplearning4j.

The class VaeGradientCheckTests, method testVaePretrainReconstructionDistributions:

@Test
public void testVaePretrainReconstructionDistributions() {
    int inOutSize = 6;
    ReconstructionDistribution[] reconstructionDistributions = new ReconstructionDistribution[] {
            new GaussianReconstructionDistribution(Activation.IDENTITY),
            new GaussianReconstructionDistribution(Activation.TANH),
            new BernoulliReconstructionDistribution(Activation.SIGMOID),
            new CompositeReconstructionDistribution.Builder()
                    .addDistribution(2, new GaussianReconstructionDistribution(Activation.IDENTITY))
                    .addDistribution(2, new BernoulliReconstructionDistribution())
                    .addDistribution(2, new GaussianReconstructionDistribution(Activation.TANH)).build(),
            new ExponentialReconstructionDistribution("identity"),
            new ExponentialReconstructionDistribution("tanh"),
            new LossFunctionWrapper(new ActivationTanH(), new LossMSE()),
            new LossFunctionWrapper(new ActivationIdentity(), new LossMAE()) };
    Nd4j.getRandom().setSeed(12345);
    for (int minibatch : new int[] { 1, 5 }) {
        for (int i = 0; i < reconstructionDistributions.length; i++) {
            INDArray data;
            switch(i) {
                case 0: // Gaussian + identity
                case 1: // Gaussian + tanh
                    data = Nd4j.rand(minibatch, inOutSize);
                    break;
                case 2: // Bernoulli
                    data = Nd4j.create(minibatch, inOutSize);
                    Nd4j.getExecutioner().exec(new BernoulliDistribution(data, 0.5), Nd4j.getRandom());
                    break;
                case 3: // Composite
                    data = Nd4j.create(minibatch, inOutSize);
                    data.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 2)).assign(Nd4j.rand(minibatch, 2));
                    Nd4j.getExecutioner().exec(new BernoulliDistribution(data.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 4)), 0.5), Nd4j.getRandom());
                    data.get(NDArrayIndex.all(), NDArrayIndex.interval(4, 6)).assign(Nd4j.rand(minibatch, 2));
                    break;
                case 4:
                case 5:
                    data = Nd4j.rand(minibatch, inOutSize);
                    break;
                case 6:
                case 7:
                    data = Nd4j.randn(minibatch, inOutSize);
                    break;
                default:
                    throw new RuntimeException();
            }
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .regularization(true).l2(0.2).l1(0.3)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .learningRate(1.0).seed(12345L)
                    .weightInit(WeightInit.DISTRIBUTION).dist(new NormalDistribution(0, 1))
                    .list()
                    .layer(0, new VariationalAutoencoder.Builder()
                            .nIn(inOutSize).nOut(3)
                            .encoderLayerSizes(5).decoderLayerSizes(6)
                            .pzxActivationFunction(Activation.TANH)
                            .reconstructionDistribution(reconstructionDistributions[i])
                            .activation(Activation.TANH).updater(Updater.SGD).build())
                    .pretrain(true).backprop(false).build();
            MultiLayerNetwork mln = new MultiLayerNetwork(conf);
            mln.init();
            mln.initGradientsView();
            org.deeplearning4j.nn.api.Layer layer = mln.getLayer(0);
            String msg = "testVaePretrainReconstructionDistributions() - " + reconstructionDistributions[i];
            if (PRINT_RESULTS) {
                System.out.println(msg);
                for (int j = 0; j < mln.getnLayers(); j++) System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams());
            }
            boolean gradOK = GradientCheckUtil.checkGradientsPretrainLayer(layer, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, data, 12345);
            assertTrue(msg, gradOK);
        }
    }
}
Also used: LossMAE (org.nd4j.linalg.lossfunctions.impl.LossMAE), MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration), ActivationIdentity (org.nd4j.linalg.activations.impl.ActivationIdentity), org.deeplearning4j.nn.api, MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork), NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration), LossMSE (org.nd4j.linalg.lossfunctions.impl.LossMSE), INDArray (org.nd4j.linalg.api.ndarray.INDArray), BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution), NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution), ActivationTanH (org.nd4j.linalg.activations.impl.ActivationTanH), Test (org.junit.Test)
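
A detail worth noting in the Bernoulli and composite cases above: the test pre-allocates a zeroed array and passes Nd4j.getRandom() explicitly to exec(), so the op fills the existing buffer in place using the seeded generator. A minimal sketch of that reproducible-sampling pattern (class name and sizes are illustrative):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;

public class SeededBernoulliSketch {
    public static void main(String[] args) {
        Nd4j.getRandom().setSeed(12345);
        INDArray data = Nd4j.create(5, 6); // pre-allocated, zero-filled target
        // Passing the RNG explicitly ties the draw to the seed set above,
        // so repeated runs produce the same 0/1 pattern.
        Nd4j.getExecutioner().exec(new BernoulliDistribution(data, 0.5), Nd4j.getRandom());
        System.out.println(data);
    }
}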

Example 8 with BernoulliDistribution

Use of org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution in project nd4j by deeplearning4j.

The class GradCheckTransforms, method testPairwiseTransforms:

@Test
public void testPairwiseTransforms() {
    /*
        add, sub, mul, div, rsub, rdiv
        eq, neq, gt, lt, gte, lte, or, and, xor
        min, max
        mmul
        tensormmul
         */
    // Test transforms (pairwise)
    Nd4j.getRandom().setSeed(12345);
    List<String> allSkipped = new ArrayList<>();
    List<String> allFailed = new ArrayList<>();
    for (int i = 0; i < 23; i++) {
        boolean skipBackward = false;
        SameDiff sd = SameDiff.create();
        int nOut = 4;
        int minibatch = 5;
        SDVariable in1 = sd.var("in1", new int[] { -1, nOut });
        SDVariable in2 = sd.var("in2", new int[] { -1, nOut });
        INDArray ia = Nd4j.randn(minibatch, nOut);
        INDArray ib = Nd4j.randn(minibatch, nOut);
        SDVariable t;
        INDArray expOut;
        switch(i) {
            case 0:
                t = in1.add(in2);
                expOut = ia.add(ib);
                break;
            case 1:
                t = in1.sub(in2);
                expOut = ia.sub(ib);
                break;
            case 2:
                t = in1.mul(in2);
                expOut = ia.mul(ib);
                break;
            case 3:
                // div case skipped
                continue;
            case 4:
                t = in1.rsub(in2);
                expOut = ia.rsub(ib);
                break;
            case 5:
                t = in1.rdiv(in2);
                expOut = ia.rdiv(ib);
                break;
            case 6:
                t = sd.eq(in1, in2);
                expOut = ia.eq(ib);
                break;
            case 7:
                t = sd.neq(in1, in2);
                expOut = ia.neq(ib);
                break;
            case 8:
                t = sd.gt(in1, in2);
                expOut = ia.gt(ib);
                break;
            case 9:
                t = sd.lt(in1, in2);
                expOut = ia.lt(ib);
                break;
            case 10:
                t = sd.gte(in1, in2);
                expOut = ia.dup();
                Nd4j.getExecutioner().exec(new GreaterThanOrEqual(new INDArray[] { ia, ib }, new INDArray[] { expOut }));
                break;
            case 11:
                t = sd.lte(in1, in2);
                expOut = ia.dup();
                Nd4j.getExecutioner().exec(new LessThanOrEqual(new INDArray[] { ia, ib }, new INDArray[] { expOut }));
                break;
            case 12:
                ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5));
                ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5));
                t = sd.or(in1, in2);
                expOut = Transforms.or(ia, ib);
                break;
            case 13:
                ib = Nd4j.randn(nOut, nOut);
                t = sd.mmul(in1, in2);
                expOut = ia.mmul(ib);
                break;
            case 14:
                t = sd.max(in1, in2);
                expOut = Nd4j.getExecutioner().execAndReturn(new OldMax(ia, ib, ia.dup(), ia.length()));
                break;
            case 15:
                t = sd.min(in1, in2);
                expOut = Nd4j.getExecutioner().execAndReturn(new OldMin(ia, ib, ia.dup(), ia.length()));
                break;
            case 16:
                ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5));
                ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5));
                t = sd.and(in1, in2);
                expOut = Transforms.and(ia, ib);
                break;
            case 17:
                ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5));
                ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5));
                t = sd.xor(in1, in2);
                expOut = Transforms.xor(ia, ib);
                break;
            case 18:
                t = sd.assign(in1, in2);
                expOut = ib;
                break;
            case 19:
                t = sd.atan2(in1, in2);
                // Note: y,x order for samediff; x,y order for transforms
                expOut = Transforms.atan2(ib, ia);
                skipBackward = true;
                break;
            case 20:
                t = sd.mergeAdd(in1, in2, in2);
                expOut = ia.add(ib).add(ib);
                break;
            case 21:
                ia = Nd4j.create(new float[] { 2, 4 });
                ib = Nd4j.create(new float[] { 42, 2 });
                in1 = sd.var("in1", new int[] { 1, 2 });
                in2 = sd.var("in2", new int[] { 1, 2 });
                t = in1.truncatedDiv(in2);
                expOut = Nd4j.create(ia.shape(), ia.ordering());
                Nd4j.getExecutioner().exec(new TruncateDivOp(ia, ib, expOut));
                skipBackward = true;
                break;
            case 22:
                t = in1.squaredDifference(in2);
                expOut = Nd4j.create(ia.shape(), ia.ordering());
                DynamicCustomOp squareDiff = DynamicCustomOp.builder("squaredsubtract").addInputs(ia, ib).addOutputs(expOut).build();
                Nd4j.getExecutioner().exec(squareDiff);
                skipBackward = true;
                break;
            default:
                throw new RuntimeException();
        }
        DifferentialFunction[] funcs = sd.functions();
        String name = funcs[0].opName();
        String msg = "test: " + i + " - " + name;
        log.info("*** Starting test: " + msg);
        SDVariable loss = sd.mean("loss", t);
        sd.associateArrayWithVariable(ia, in1);
        sd.associateArrayWithVariable(ib, in2);
        sd.exec();
        INDArray out = t.getArr();
        assertEquals(msg, expOut, out);
        boolean ok;
        if (skipBackward) {
            ok = true;
            msg += " - SKIPPED";
            allSkipped.add(msg);
        } else {
            try {
                ok = GradCheckUtil.checkGradients(sd);
            } catch (Exception e) {
                e.printStackTrace();
                msg += " - EXCEPTION";
                ok = false;
            }
        }
        if (!ok) {
            allFailed.add(msg);
        }
    }
    if (allSkipped.size() > 0) {
        log.info("All backward skipped transforms: " + allSkipped);
        log.info(allSkipped.size() + " backward passes were skipped.");
    }
    if (allFailed.size() > 0) {
        log.error("All failed transforms: " + allFailed);
        fail(allFailed.size() + " transforms failed");
    }
}
Also used: LessThanOrEqual (org.nd4j.linalg.api.ops.impl.transforms.comparison.LessThanOrEqual), OldMin (org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin), SameDiff (org.nd4j.autodiff.samediff.SameDiff), ArrayList (java.util.ArrayList), GreaterThanOrEqual (org.nd4j.linalg.api.ops.impl.transforms.comparison.GreaterThanOrEqual), SDVariable (org.nd4j.autodiff.samediff.SDVariable), OldMax (org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax), INDArray (org.nd4j.linalg.api.ndarray.INDArray), BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution), DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction), DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp), TruncateDivOp (org.nd4j.linalg.api.ops.impl.transforms.arithmetic.TruncateDivOp), Test (org.junit.Test)
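
For the or/and/xor cases, the test first overwrites its Gaussian inputs with Bernoulli(0.5) samples so the boolean transforms receive {0, 1} operands. A standalone sketch of that binarize-then-combine step (class name and shapes are illustrative):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

public class BooleanTransformSketch {
    public static void main(String[] args) {
        Nd4j.getRandom().setSeed(12345);
        // Draw two binary arrays, then combine them element-wise.
        INDArray a = Nd4j.getExecutioner()
                .exec(new BernoulliDistribution(Nd4j.createUninitialized(new int[] { 5, 4 }), 0.5));
        INDArray b = Nd4j.getExecutioner()
                .exec(new BernoulliDistribution(Nd4j.createUninitialized(new int[] { 5, 4 }), 0.5));
        System.out.println(Transforms.or(a, b));  // 1 where either input is 1
        System.out.println(Transforms.and(a, b)); // 1 where both inputs are 1
        System.out.println(Transforms.xor(a, b)); // 1 where exactly one input is 1
    }
}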

Example 9 with BernoulliDistribution

Use of org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution in project nd4j by deeplearning4j.

The class GradCheckLoss, method testLossWeights3d:

@Test
public void testLossWeights3d() {
    String[] weightTypes = new String[] { "none", "per-example", "per-output", "per-timestep",
            "per-example-output", "per-example-timestep", "per-output-timestep", "per-all" };
    Nd4j.getRandom().setSeed(12345);
    int nOut = 4;
    int minibatch = 10;
    int tsLength = 5;
    for (String weightType : weightTypes) {
        for (boolean binary : new boolean[] { true, false }) {
            // Binary mask (like DL4J) or arbitrary weights?
            int[] weightShape;
            switch(weightType) {
                case "none":
                    weightShape = null;
                    break;
                case "per-example":
                    weightShape = new int[] { minibatch, 1, 1 };
                    break;
                case "per-output":
                    weightShape = new int[] { 1, nOut, 1 };
                    break;
                case "per-timestep":
                    weightShape = new int[] { 1, 1, tsLength };
                    break;
                case "per-example-output":
                    weightShape = new int[] { minibatch, nOut, 1 };
                    break;
                case "per-example-timestep":
                    weightShape = new int[] { minibatch, 1, tsLength };
                    break;
                case "per-output-timestep":
                    weightShape = new int[] { 1, nOut, tsLength };
                    break;
                case "per-all":
                    weightShape = new int[] { minibatch, nOut, tsLength };
                    break;
                default:
                    throw new RuntimeException("Unknown type: " + weightType);
            }
            INDArray weightArr = null;
            if (!"none".equals(weightType)) {
                if (binary) {
                    weightArr = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(weightShape), 0.5));
                } else {
                    weightArr = Nd4j.rand(weightShape).muli(2.0);
                }
            }
            for (LossFunctions.Reduction reduction : new LossFunctions.Reduction[] {
                    LossFunctions.Reduction.MEAN_BY_COUNT, LossFunctions.Reduction.MEAN_BY_WEIGHT,
                    LossFunctions.Reduction.SUM }) {
                for (String fn : new String[] { "mse", "l1", "l2", "mcxent" }) {
                    SameDiff sd = SameDiff.create();
                    SDVariable input = sd.var("in", new int[] { -1, nOut, -1 });
                    SDVariable labels = sd.var("labels", new int[] { -1, nOut, -1 });
                    SDVariable weight = null;
                    if (!"none".equals(weightType)) {
                        weight = sd.var("weights", weightArr);
                    }
                    INDArray inputArr = Nd4j.randn(new int[] { minibatch, nOut, tsLength }).muli(10);
                    INDArray labelsArr = Nd4j.randn(new int[] { minibatch, nOut, tsLength }).muli(10);
                    LossInfo lossInfo;
                    switch(fn) {
                        case "mse":
                            lossInfo = LossFunctions.mse("out", input, labels, weight, reduction, 1, 2);
                            break;
                        case "l1":
                            lossInfo = LossFunctions.l1("out", input, labels, weight, reduction, 1, 2);
                            // L1 = sum abs error
                            break;
                        case "l2":
                            lossInfo = LossFunctions.l2("out", input, labels, weight, reduction, 1, 2);
                            // L2 = sum squared error
                            break;
                        case "mcxent":
                            lossInfo = LossFunctions.mcxent("out", input, labels, weight, reduction, 1, 2);
                            // mcxent = -(sum of label * log(prob))
                            break;
                        default:
                            throw new RuntimeException();
                    }
                    String msg = "lossFn=" + fn + ", reduction=" + reduction + ", weightType=" + weightType + ", binaryWeight=" + binary;
                    log.info("*** Starting test: " + msg);
                    sd.associateArrayWithVariable(inputArr, input);
                    sd.associateArrayWithVariable(labelsArr, labels);
                    if (weight != null) {
                        sd.associateArrayWithVariable(weightArr, weight);
                    }
                    INDArray out = sd.execAndEndResult();
                    assertEquals(1, out.length());
                    boolean ok = GradCheckUtil.checkGradients(sd);
                    assertTrue(msg, ok);
                }
            }
        }
    }
}
Also used: SameDiff (org.nd4j.autodiff.samediff.SameDiff), LossFunctions (org.nd4j.autodiff.loss.LossFunctions), LossInfo (org.nd4j.autodiff.loss.LossInfo), SDVariable (org.nd4j.autodiff.samediff.SDVariable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution), Test (org.junit.Test)

Example 10 with BernoulliDistribution

Use of org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution in project nd4j by deeplearning4j.

The class BaseUnderSamplingPreProcessor, method adjustMasks:

public INDArray adjustMasks(INDArray label, INDArray labelMask, int minorityLabel, double targetDist) {
    if (labelMask == null) {
        labelMask = Nd4j.ones(label.size(0), label.size(2));
    }
    validateData(label, labelMask);
    INDArray bernoullis = Nd4j.zeros(labelMask.shape());
    int currentTimeSliceEnd = label.size(2);
    // iterate over each tbptt window
    while (currentTimeSliceEnd > 0) {
        int currentTimeSliceStart = Math.max(currentTimeSliceEnd - tbpttWindowSize, 0);
        // get views for current time slice
        INDArray currentWindowBernoulli = bernoullis.get(NDArrayIndex.all(), NDArrayIndex.interval(currentTimeSliceStart, currentTimeSliceEnd));
        INDArray currentMask = labelMask.get(NDArrayIndex.all(), NDArrayIndex.interval(currentTimeSliceStart, currentTimeSliceEnd));
        INDArray currentLabel;
        if (label.size(1) == 2) {
            // if one-hot, grab the right index
            currentLabel = label.get(NDArrayIndex.all(), NDArrayIndex.point(minorityLabel), NDArrayIndex.interval(currentTimeSliceStart, currentTimeSliceEnd));
        } else {
            currentLabel = label.get(NDArrayIndex.all(), NDArrayIndex.point(0), NDArrayIndex.interval(currentTimeSliceStart, currentTimeSliceEnd));
            if (minorityLabel == 0) {
                currentLabel = Transforms.not(currentLabel);
            }
        }
        // calculate required probabilities and write into the view
        currentWindowBernoulli.assign(calculateBernoulli(currentLabel, currentMask, targetDist));
        currentTimeSliceEnd = currentTimeSliceStart;
    }
    return Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(bernoullis.shape()), bernoullis), Nd4j.getRandom());
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution)
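
Note the constructor used on the final line of adjustMasks: unlike the scalar-probability form in the earlier examples, this overload of BernoulliDistribution takes an INDArray of per-element probabilities, so each mask entry is drawn with its own keep-probability. A minimal sketch of the per-element form (class name and values are illustrative):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;

public class PerElementBernoulliSketch {
    public static void main(String[] args) {
        Nd4j.getRandom().setSeed(12345);
        // One probability per element: entry i is 1 with probability probs.getDouble(i).
        INDArray probs = Nd4j.create(new double[] { 0.1, 0.5, 0.9, 1.0 });
        INDArray sample = Nd4j.getExecutioner().exec(
                new BernoulliDistribution(Nd4j.createUninitialized(probs.shape()), probs),
                Nd4j.getRandom());
        System.out.println(sample); // e.g. [0.00, 1.00, 1.00, 1.00]; the last entry is always 1
    }
}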

Aggregations

INDArray (org.nd4j.linalg.api.ndarray.INDArray): 12 uses
BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution): 12 uses
Test (org.junit.Test): 10 uses
SDVariable (org.nd4j.autodiff.samediff.SDVariable): 4 uses
SameDiff (org.nd4j.autodiff.samediff.SameDiff): 4 uses
LossFunctions (org.nd4j.autodiff.loss.LossFunctions): 3 uses
LossInfo (org.nd4j.autodiff.loss.LossInfo): 3 uses
BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest): 3 uses
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 2 uses
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 2 uses
NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution): 2 uses
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 2 uses
ActivationTanH (org.nd4j.linalg.activations.impl.ActivationTanH): 2 uses
ArrayList (java.util.ArrayList): 1 use
org.deeplearning4j.nn.api: 1 use
org.deeplearning4j.nn.conf.layers.variational: 1 use
VariationalAutoencoder (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder): 1 use
DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction): 1 use
IActivation (org.nd4j.linalg.activations.IActivation): 1 use
ActivationIdentity (org.nd4j.linalg.activations.impl.ActivationIdentity): 1 use