Search in sources:

Example 11 with BernoulliDistribution

Use of org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution in project nd4j by deeplearning4j.

In class LossFunctionTest, method testClippingXENT.

/**
 * Checks that binary cross-entropy with saturated sigmoid outputs produces NaN when built with
 * {@code new LossBinaryXENT(0)}, and stays finite (no NaN) with the default {@code LossBinaryXENT()}.
 * NOTE(review): the constructor argument is presumably the clipping epsilon, so 0 disables
 * clipping. Confirm against the LossBinaryXENT API.
 */
@Test
public void testClippingXENT() throws Exception {
    ILossFunction unclipped = new LossBinaryXENT(0);
    ILossFunction clipped = new LossBinaryXENT();

    // Random 0/1 labels (p = 0.5) for a 3x5 minibatch; generated first so RNG consumption is unchanged.
    INDArray labelArr = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.create(3, 5), 0.5));
    // Pre-activations of -1000 drive sigmoid to (numerically) 0, exercising the log(0) edge case.
    INDArray input = Nd4j.valueArrayOf(3, 5, -1000.0);
    IActivation sigmoid = new ActivationSigmoid();

    // Every call receives its own copy of the pre-activations.
    // Presumably this is because the loss may modify them in place.
    assertTrue(Double.isNaN(unclipped.computeScore(labelArr, input.dup(), sigmoid, null, false)));
    assertFalse(Double.isNaN(clipped.computeScore(labelArr, input.dup(), sigmoid, null, false)));

    INDArray unclippedGrad = unclipped.computeGradient(labelArr, input.dup(), sigmoid, null);
    INDArray clippedGrad = clipped.computeGradient(labelArr, input.dup(), sigmoid, null);

    // Count NaN entries in each gradient via a MatchCondition reduction.
    MatchCondition unclippedNaNQuery = new MatchCondition(unclippedGrad, Conditions.isNan());
    MatchCondition clippedNaNQuery = new MatchCondition(clippedGrad, Conditions.isNan());
    int unclippedNaNs = Nd4j.getExecutioner().exec(unclippedNaNQuery, Integer.MAX_VALUE).getInt(0);
    int clippedNaNs = Nd4j.getExecutioner().exec(clippedNaNQuery, Integer.MAX_VALUE).getInt(0);

    assertTrue(unclippedNaNs > 0);
    assertEquals(0, clippedNaNs);
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) BernoulliDistribution(org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution) ActivationSigmoid(org.nd4j.linalg.activations.impl.ActivationSigmoid) MatchCondition(org.nd4j.linalg.api.ops.impl.accum.MatchCondition) IActivation(org.nd4j.linalg.activations.IActivation) Test(org.junit.Test) BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)

Example 12 with BernoulliDistribution

Use of org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution in project nd4j by deeplearning4j.

In class MultiDataSetTest, method testMergingWithPerOutputMasking.

/**
 * Verifies that {@code MultiDataSet.merge} concatenates features, labels and (label) mask arrays
 * along the minibatch dimension for several shape combinations:
 * 2d data with 2d masks, 4d features with 2d labels/masks, 3d data with 3d masks
 * (different time-series lengths), and 3d features with 2d labels/masks.
 * In each case the first dataset has minibatch size 1 and the second has minibatch size 2.
 */
@Test
public void testMergingWithPerOutputMasking() {
    // Test 2d mask merging, 2d data
    // features
    INDArray f2d1 = Nd4j.create(new double[] { 1, 2, 3 });
    INDArray f2d2 = Nd4j.create(new double[][] { { 4, 5, 6 }, { 7, 8, 9 } });
    // labels
    INDArray l2d1 = Nd4j.create(new double[] { 1.5, 2.5, 3.5 });
    INDArray l2d2 = Nd4j.create(new double[][] { { 4.5, 5.5, 6.5 }, { 7.5, 8.5, 9.5 } });
    // feature masks
    INDArray fm2d1 = Nd4j.create(new double[] { 0, 1, 1 });
    INDArray fm2d2 = Nd4j.create(new double[][] { { 1, 0, 1 }, { 0, 1, 0 } });
    // label masks
    INDArray lm2d1 = Nd4j.create(new double[] { 1, 1, 0 });
    INDArray lm2d2 = Nd4j.create(new double[][] { { 1, 0, 0 }, { 0, 1, 1 } });
    MultiDataSet mds2d1 = new MultiDataSet(f2d1, l2d1, fm2d1, lm2d1);
    MultiDataSet mds2d2 = new MultiDataSet(f2d2, l2d2, fm2d2, lm2d2);
    MultiDataSet merged = MultiDataSet.merge(Arrays.asList(mds2d1, mds2d2));
    INDArray expFeatures2d = Nd4j.create(new double[][] { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } });
    INDArray expLabels2d = Nd4j.create(new double[][] { { 1.5, 2.5, 3.5 }, { 4.5, 5.5, 6.5 }, { 7.5, 8.5, 9.5 } });
    INDArray expFM2d = Nd4j.create(new double[][] { { 0, 1, 1 }, { 1, 0, 1 }, { 0, 1, 0 } });
    INDArray expLM2d = Nd4j.create(new double[][] { { 1, 1, 0 }, { 1, 0, 0 }, { 0, 1, 1 } });
    MultiDataSet mdsExp2d = new MultiDataSet(expFeatures2d, expLabels2d, expFM2d, expLM2d);
    assertEquals(mdsExp2d, merged);
    // Test 4d features, 2d labels, 2d masks
    INDArray f4d1 = Nd4j.create(1, 3, 5, 5);
    INDArray f4d2 = Nd4j.create(2, 3, 5, 5);
    MultiDataSet mds4d1 = new MultiDataSet(f4d1, l2d1, null, lm2d1);
    MultiDataSet mds4d2 = new MultiDataSet(f4d2, l2d2, null, lm2d2);
    MultiDataSet merged4d = MultiDataSet.merge(Arrays.asList(mds4d1, mds4d2));
    assertEquals(expLabels2d, merged4d.getLabels(0));
    assertEquals(expLM2d, merged4d.getLabelsMaskArray(0));
    // Test 3d mask merging, 3d data
    INDArray f3d1 = Nd4j.create(1, 3, 4);
    // Minibatch size 2, matching l3d2/lm3d2 below and l2d2/lm2d2 in the 3d->2d case.
    // Previously this had minibatch size 1, which made the merged features inconsistent with the labels.
    INDArray f3d2 = Nd4j.create(2, 3, 3);
    INDArray l3d1 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.create(1, 3, 4), 0.5));
    INDArray l3d2 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.create(2, 3, 3), 0.5));
    INDArray lm3d1 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.create(1, 3, 4), 0.5));
    INDArray lm3d2 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.create(2, 3, 3), 0.5));
    MultiDataSet mds3d1 = new MultiDataSet(f3d1, l3d1, null, lm3d1);
    MultiDataSet mds3d2 = new MultiDataSet(f3d2, l3d2, null, lm3d2);
    // Expected merge: the shorter series (length 3) occupies the first 3 steps of rows 1-2,
    // and the remaining entries stay at the zero value from Nd4j.create.
    INDArray expLabels3d = Nd4j.create(3, 3, 4);
    expLabels3d.put(new INDArrayIndex[] { NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.interval(0, 4) }, l3d1);
    expLabels3d.put(new INDArrayIndex[] { NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(), NDArrayIndex.interval(0, 3) }, l3d2);
    INDArray expLM3d = Nd4j.create(3, 3, 4);
    expLM3d.put(new INDArrayIndex[] { NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.interval(0, 4) }, lm3d1);
    expLM3d.put(new INDArrayIndex[] { NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(), NDArrayIndex.interval(0, 3) }, lm3d2);
    MultiDataSet merged3d = MultiDataSet.merge(Arrays.asList(mds3d1, mds3d2));
    assertEquals(expLabels3d, merged3d.getLabels(0));
    assertEquals(expLM3d, merged3d.getLabelsMaskArray(0));
    // Test 3d features, 2d masks, 2d output (for example: RNN -> global pooling w/ per-output masking)
    MultiDataSet mds3d2d1 = new MultiDataSet(f3d1, l2d1, null, lm2d1);
    MultiDataSet mds3d2d2 = new MultiDataSet(f3d2, l2d2, null, lm2d2);
    MultiDataSet merged3d2d = MultiDataSet.merge(Arrays.asList(mds3d2d1, mds3d2d2));
    assertEquals(expLabels2d, merged3d2d.getLabels(0));
    assertEquals(expLM2d, merged3d2d.getLabelsMaskArray(0));
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) BernoulliDistribution(org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution) Test(org.junit.Test) BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)

Aggregations

INDArray (org.nd4j.linalg.api.ndarray.INDArray)12 BernoulliDistribution (org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution)12 Test (org.junit.Test)10 SDVariable (org.nd4j.autodiff.samediff.SDVariable)4 SameDiff (org.nd4j.autodiff.samediff.SameDiff)4 LossFunctions (org.nd4j.autodiff.loss.LossFunctions)3 LossInfo (org.nd4j.autodiff.loss.LossInfo)3 BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)3 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)2 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)2 NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution)2 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)2 ActivationTanH (org.nd4j.linalg.activations.impl.ActivationTanH)2 ArrayList (java.util.ArrayList)1 org.deeplearning4j.nn.api (org.deeplearning4j.nn.api)1 org.deeplearning4j.nn.conf.layers.variational (org.deeplearning4j.nn.conf.layers.variational)1 VariationalAutoencoder (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder)1 DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction)1 IActivation (org.nd4j.linalg.activations.IActivation)1 ActivationIdentity (org.nd4j.linalg.activations.impl.ActivationIdentity)1