Usage of org.deeplearning4j.datasets.iterator.ExistingDataSetIterator in project deeplearning4j (by deeplearning4j).
In class TestMasking, method checkMaskArrayClearance.
/**
 * Verifies that no layer of a {@link MultiLayerNetwork} still holds a mask array after fitting
 * on masked sequence data, for both standard and truncated BPTT.
 *
 * <p>The same masked {@link DataSet} is fitted three ways: via {@code fit(DataSet)}, via
 * {@code fit(features, labels, featuresMask, labelsMask)}, and via a {@link DataSetIterator}.
 * After each call every layer's mask array must be {@code null}.
 */
@Test
public void checkMaskArrayClearance() {
    for (boolean tbptt : new boolean[] {true, false}) {
        String mode = tbptt ? "TruncatedBPTT" : "Standard";

        // Single RNN output layer (1 in -> 1 out). TBPTT lengths are 8, shorter than the
        // 10-step sequence built below (assumes reshape(1, 1, 10) is [miniBatch, size, timeSteps]).
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .iterations(1)
                .seed(12345)
                .list()
                .layer(0, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .activation(Activation.IDENTITY)
                        .nIn(1)
                        .nOut(1)
                        .build())
                .backpropType(tbptt ? BackpropType.TruncatedBPTT : BackpropType.Standard)
                .tBPTTForwardLength(8)
                .tBPTTBackwardLength(8)
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // All-ones feature and label masks: nothing is masked out, but mask arrays are supplied,
        // so there is something for the network to (fail to) clear.
        DataSet data = new DataSet(
                Nd4j.linspace(1, 10, 10).reshape(1, 1, 10),
                Nd4j.linspace(2, 20, 10).reshape(1, 1, 10),
                Nd4j.ones(10),
                Nd4j.ones(10));

        net.fit(data);
        assertMaskArraysCleared(net, mode + ", fit(DataSet)");

        net.fit(data.getFeatures(), data.getLabels(), data.getFeaturesMaskArray(),
                data.getLabelsMaskArray());
        assertMaskArraysCleared(net, mode + ", fit(features, labels, featuresMask, labelsMask)");

        DataSetIterator iter = new ExistingDataSetIterator(Collections.singletonList(data).iterator());
        net.fit(iter);
        assertMaskArraysCleared(net, mode + ", fit(DataSetIterator)");
    }
}

/**
 * Asserts that no layer of {@code net} still holds a mask array.
 *
 * @param net     the network whose layers are inspected
 * @param context description of the fit call just performed; included in the failure message
 */
private static void assertMaskArraysCleared(MultiLayerNetwork net, String context) {
    for (Layer layer : net.getLayers()) {
        assertNull("Mask array not cleared after " + context, layer.getMaskArray());
    }
}
Usage of org.deeplearning4j.datasets.iterator.ExistingDataSetIterator in project deeplearning4j (by deeplearning4j).
In class MagicQueueTest, method testSequentialIterable.
/**
 * Drains an {@link AsyncDataSetIterator} backed by a SEQUENTIAL-mode {@link MagicQueue} and checks
 * three things:
 * <ul>
 *   <li>every returned DataSet is non-null;</li>
 *   <li>its features and labels are placed on device {@code round % numDevices}, i.e. devices are
 *       cycled through in order;</li>
 *   <li>exactly as many DataSets come out as went in.</li>
 * </ul>
 */
@Test
public void testSequentialIterable() throws Exception {
    final int numDataSets = 1024;
    List<DataSet> list = new ArrayList<>(numDataSets);
    for (int i = 0; i < numDataSets; i++) {
        list.add(new DataSet(Nd4j.create(new float[] {1f, 2f, 3f}),
                Nd4j.create(new float[] {1f, 2f, 3f})));
    }

    int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();

    ExistingDataSetIterator edsi = new ExistingDataSetIterator(list);
    MagicQueue queue = new MagicQueue.Builder()
            .setMode(MagicQueue.Mode.SEQUENTIAL)
            .setCapacityPerFlow(32)
            .build();
    AsyncDataSetIterator adsi = new AsyncDataSetIterator(edsi, 10, queue);

    int cnt = 0;
    while (adsi.hasNext()) {
        DataSet ds = adsi.next();
        String failure = "Failed on round " + cnt;

        // making sure dataset isn't null
        assertNotEquals(failure, null, ds);

        // making sure device for this array is the "next one" in the cycle
        int expectedDevice = cnt % numDevices;
        assertEquals(failure + " (features device)", expectedDevice,
                Nd4j.getAffinityManager().getDeviceForArray(ds.getFeatures()).intValue());
        assertEquals(failure + " (labels device)", expectedDevice,
                Nd4j.getAffinityManager().getDeviceForArray(ds.getLabels()).intValue());
        cnt++;
    }
    assertEquals("Number of DataSets returned by the async iterator", list.size(), cnt);
}
Aggregations