Usage of `org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution` in the nd4j project (deeplearning4j).
Example from class `LossFunctionTest`, method `testClippingXENT`:
/**
 * Checks the clipping behaviour of {@code LossBinaryXENT} on a saturated sigmoid output.
 *
 * <p>All pre-activations are -1000, so the sigmoid output is driven to 0. With clipping disabled
 * ({@code new LossBinaryXENT(0)}) both the score and the gradient are expected to contain NaN.
 * The default-constructed loss is expected to produce neither.
 */
@Test
public void testClippingXENT() throws Exception {
    ILossFunction l1 = new LossBinaryXENT(0);
    ILossFunction l2 = new LossBinaryXENT();

    INDArray labels = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.create(3, 5), 0.5));
    // NOTE(review): presumably the unclipped score is labels*log(p) + (1-labels)*log(1-p). With
    // p == 0, a label of 0 gives 0 * log(0) = NaN, while a label of 1 only gives -Infinity.
    // Force at least one 0 label so the NaN expectations below do not depend on the random draw.
    // If all 15 Bernoulli samples came out as 1, the test would otherwise fail.
    labels.putScalar(new int[] {0, 0}, 0.0);

    INDArray preOut = Nd4j.valueArrayOf(3, 5, -1000.0);
    IActivation a = new ActivationSigmoid();

    double score1 = l1.computeScore(labels, preOut.dup(), a, null, false);
    assertTrue("Unclipped XENT score should be NaN for saturated output", Double.isNaN(score1));
    double score2 = l2.computeScore(labels, preOut.dup(), a, null, false);
    assertFalse("Clipped XENT score should not be NaN", Double.isNaN(score2));

    INDArray grad1 = l1.computeGradient(labels, preOut.dup(), a, null);
    INDArray grad2 = l2.computeGradient(labels, preOut.dup(), a, null);

    // Count NaN entries in each gradient
    MatchCondition c1 = new MatchCondition(grad1, Conditions.isNan());
    MatchCondition c2 = new MatchCondition(grad2, Conditions.isNan());
    int match1 = Nd4j.getExecutioner().exec(c1, Integer.MAX_VALUE).getInt(0);
    int match2 = Nd4j.getExecutioner().exec(c2, Integer.MAX_VALUE).getInt(0);

    assertTrue("Unclipped XENT gradient should contain NaN", match1 > 0);
    assertEquals("Clipped XENT gradient should contain no NaN", 0, match2);
}
Usage of `org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution` in the nd4j project (deeplearning4j).
Example from class `MultiDataSetTest`, method `testMergingWithPerOutputMasking`:
/**
 * Checks that {@code MultiDataSet.merge} stacks examples along dimension 0 and carries label masks
 * through the merge for several feature/label rank combinations.
 *
 * <p>The combinations are: 2d features with 2d labels, 4d features with 2d labels, 3d features
 * with 3d labels (time series of different lengths), and 3d features with 2d labels (for example
 * an RNN followed by global pooling with per-output masking). In every input MultiDataSet the
 * features, labels and masks share the same minibatch size. The first set always holds 1 example
 * and the second set holds 2.
 */
@Test
public void testMergingWithPerOutputMasking() {
    // Test 2d mask merging, 2d data
    // features
    INDArray f2d1 = Nd4j.create(new double[] { 1, 2, 3 });
    INDArray f2d2 = Nd4j.create(new double[][] { { 4, 5, 6 }, { 7, 8, 9 } });
    // labels
    INDArray l2d1 = Nd4j.create(new double[] { 1.5, 2.5, 3.5 });
    INDArray l2d2 = Nd4j.create(new double[][] { { 4.5, 5.5, 6.5 }, { 7.5, 8.5, 9.5 } });
    // feature masks
    INDArray fm2d1 = Nd4j.create(new double[] { 0, 1, 1 });
    INDArray fm2d2 = Nd4j.create(new double[][] { { 1, 0, 1 }, { 0, 1, 0 } });
    // label masks
    INDArray lm2d1 = Nd4j.create(new double[] { 1, 1, 0 });
    INDArray lm2d2 = Nd4j.create(new double[][] { { 1, 0, 0 }, { 0, 1, 1 } });

    MultiDataSet mds2d1 = new MultiDataSet(f2d1, l2d1, fm2d1, lm2d1);
    MultiDataSet mds2d2 = new MultiDataSet(f2d2, l2d2, fm2d2, lm2d2);
    MultiDataSet merged = MultiDataSet.merge(Arrays.asList(mds2d1, mds2d2));

    // Expected: rows of the first set followed by rows of the second set
    INDArray expFeatures2d = Nd4j.create(new double[][] { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } });
    INDArray expLabels2d = Nd4j.create(new double[][] { { 1.5, 2.5, 3.5 }, { 4.5, 5.5, 6.5 }, { 7.5, 8.5, 9.5 } });
    INDArray expFM2d = Nd4j.create(new double[][] { { 0, 1, 1 }, { 1, 0, 1 }, { 0, 1, 0 } });
    INDArray expLM2d = Nd4j.create(new double[][] { { 1, 1, 0 }, { 1, 0, 0 }, { 0, 1, 1 } });
    MultiDataSet mdsExp2d = new MultiDataSet(expFeatures2d, expLabels2d, expFM2d, expLM2d);
    assertEquals(mdsExp2d, merged);

    // Test 4d features, 2d labels, 2d masks
    INDArray f4d1 = Nd4j.create(1, 3, 5, 5);
    INDArray f4d2 = Nd4j.create(2, 3, 5, 5);
    MultiDataSet mds4d1 = new MultiDataSet(f4d1, l2d1, null, lm2d1);
    MultiDataSet mds4d2 = new MultiDataSet(f4d2, l2d2, null, lm2d2);
    MultiDataSet merged4d = MultiDataSet.merge(Arrays.asList(mds4d1, mds4d2));
    assertEquals(expLabels2d, merged4d.getLabels(0));
    assertEquals(expLM2d, merged4d.getLabelsMaskArray(0));

    // Test 3d mask merging, 3d data.
    // First set: 1 example of length 4. Second set: 2 examples of length 3. The feature minibatch
    // size must match the label minibatch size, as in the 4d case above.
    INDArray f3d1 = Nd4j.create(1, 3, 4);
    INDArray f3d2 = Nd4j.create(2, 3, 3);
    INDArray l3d1 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.create(1, 3, 4), 0.5));
    INDArray l3d2 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.create(2, 3, 3), 0.5));
    INDArray lm3d1 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.create(1, 3, 4), 0.5));
    INDArray lm3d2 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.create(2, 3, 3), 0.5));
    MultiDataSet mds3d1 = new MultiDataSet(f3d1, l3d1, null, lm3d1);
    MultiDataSet mds3d2 = new MultiDataSet(f3d2, l3d2, null, lm3d2);

    // Expected merged arrays: example 0 fills all 4 time steps. Examples 1-2 fill time steps 0-2.
    // The last time step of examples 1-2 is left as created by Nd4j.create (padding).
    INDArray expLabels3d = Nd4j.create(3, 3, 4);
    expLabels3d.put(new INDArrayIndex[] { NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.interval(0, 4) }, l3d1);
    expLabels3d.put(new INDArrayIndex[] { NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(), NDArrayIndex.interval(0, 3) }, l3d2);
    INDArray expLM3d = Nd4j.create(3, 3, 4);
    expLM3d.put(new INDArrayIndex[] { NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.interval(0, 4) }, lm3d1);
    expLM3d.put(new INDArrayIndex[] { NDArrayIndex.interval(1, 2, true), NDArrayIndex.all(), NDArrayIndex.interval(0, 3) }, lm3d2);

    MultiDataSet merged3d = MultiDataSet.merge(Arrays.asList(mds3d1, mds3d2));
    // Features: 1 + 2 examples, padded to the longest series (4 time steps)
    assertEquals(3, merged3d.getFeatures(0).size(0));
    assertEquals(4, merged3d.getFeatures(0).size(2));
    assertEquals(expLabels3d, merged3d.getLabels(0));
    assertEquals(expLM3d, merged3d.getLabelsMaskArray(0));

    // Test 3d features, 2d masks, 2d output (for example: RNN -> global pooling w/ per-output masking)
    MultiDataSet mds3d2d1 = new MultiDataSet(f3d1, l2d1, null, lm2d1);
    MultiDataSet mds3d2d2 = new MultiDataSet(f3d2, l2d2, null, lm2d2);
    MultiDataSet merged3d2d = MultiDataSet.merge(Arrays.asList(mds3d2d1, mds3d2d2));
    assertEquals(3, merged3d2d.getFeatures(0).size(0));
    assertEquals(expLabels2d, merged3d2d.getLabels(0));
    assertEquals(expLM2d, merged3d2d.getLabelsMaskArray(0));
}
Aggregations