Use of org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator in project deeplearning4j by deeplearning4j.
In class ComputationGraphTestRNN, method checkMaskArrayClearance.
/**
 * Verifies that mask arrays do not remain set on the graph or on any of its layers
 * after training, for each {@code fit(...)} overload exercised below.
 *
 * <p>The scenario is run twice: once with truncated BPTT and once with standard
 * backprop. The training data is a single time series of length 10 with all-ones
 * feature and label masks.
 */
@Test
public void checkMaskArrayClearance() {
    for (boolean tbptt : new boolean[] { true, false }) {
        //Simple "does it throw an exception" type test...
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .iterations(1)
                .seed(12345)
                .graphBuilder()
                .addInputs("in")
                .addLayer("out", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .activation(Activation.IDENTITY).nIn(1).nOut(1).build(), "in")
                .setOutputs("out")
                .backpropType(tbptt ? BackpropType.TruncatedBPTT : BackpropType.Standard)
                .tBPTTForwardLength(8)
                .tBPTTBackwardLength(8)
                .build();
        ComputationGraph net = new ComputationGraph(conf);
        net.init();

        MultiDataSet data = new MultiDataSet(
                new INDArray[] { Nd4j.linspace(1, 10, 10).reshape(1, 1, 10) },
                new INDArray[] { Nd4j.linspace(2, 20, 10).reshape(1, 1, 10) },
                new INDArray[] { Nd4j.ones(10) },
                new INDArray[] { Nd4j.ones(10) });

        // fit(MultiDataSet)
        net.fit(data);
        assertNoMaskArraysRetained(net);

        // fit(DataSet), built from the same arrays as the MultiDataSet above
        DataSet ds = new DataSet(data.getFeatures(0), data.getLabels(0), data.getFeaturesMaskArray(0), data.getLabelsMaskArray(0));
        net.fit(ds);
        assertNoMaskArraysRetained(net);

        // fit(features[], labels[], featureMasks[], labelMasks[])
        net.fit(data.getFeatures(), data.getLabels(), data.getFeaturesMaskArrays(), data.getLabelsMaskArrays());
        assertNoMaskArraysRetained(net);

        // fit(MultiDataSetIterator)
        MultiDataSetIterator iter = new IteratorMultiDataSetIterator(Collections.singletonList((org.nd4j.linalg.dataset.api.MultiDataSet) data).iterator(), 1);
        net.fit(iter);
        assertNoMaskArraysRetained(net);

        // fit(DataSetIterator)
        DataSetIterator iter2 = new IteratorDataSetIterator(Collections.singletonList(ds).iterator(), 1);
        net.fit(iter2);
        assertNoMaskArraysRetained(net);
    }
}

/**
 * Asserts that the graph holds no input or label mask arrays and that none of its
 * layers holds a mask array.
 *
 * @param net the network to inspect after a fit call
 */
private static void assertNoMaskArraysRetained(ComputationGraph net) {
    assertNull(net.getInputMaskArrays());
    assertNull(net.getLabelMaskArrays());
    for (Layer l : net.getLayers()) {
        assertNull(l.getMaskArray());
    }
}
Use of org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator in project deeplearning4j by deeplearning4j.
In class ScoreFlatMapFunctionCGMultiDataSetAdapter, method call.
/**
 * Scores the supplied MultiDataSets in minibatches of {@code minibatchSize}, using a
 * ComputationGraph rebuilt from the JSON configuration and a copy of the broadcast
 * parameters.
 *
 * @param dataSetIterator the MultiDataSets to score
 * @return one tuple per minibatch holding (number of examples, score * number of
 *         examples); a single (0, 0.0) tuple when the iterator has no elements
 * @throws IllegalStateException if the broadcast parameter count differs from the
 *         parameter count of the rebuilt network
 */
@Override
public Iterable<Tuple2<Integer, Double>> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
    // Nothing to score: report zero examples with zero score
    if (!dataSetIterator.hasNext()) {
        return Collections.singletonList(new Tuple2<>(0, 0.0));
    }

    //Does batching where appropriate
    MultiDataSetIterator batches = new IteratorMultiDataSetIterator(dataSetIterator, minibatchSize);

    ComputationGraphConfiguration graphConf = ComputationGraphConfiguration.fromJson(json);
    ComputationGraph graph = new ComputationGraph(graphConf);
    graph.init();

    //.value() is shared by all executors on single machine -> OK, as params are not changed in score function
    INDArray paramsCopy = params.value().unsafeDuplication();
    if (paramsCopy.length() != graph.numParams(false)) {
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
    }
    graph.setParams(paramsCopy);

    List<Tuple2<Integer, Double>> scores = new ArrayList<>();
    while (batches.hasNext()) {
        MultiDataSet batch = batches.next();
        double batchScore = graph.score(batch, false);
        int exampleCount = batch.getFeatures(0).size(0);
        // The score is multiplied by the example count before being emitted
        scores.add(new Tuple2<>(exampleCount, batchScore * exampleCount));
    }

    // Flush any queued operations before handing the results back
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    return scores;
}
Use of org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator in project deeplearning4j by deeplearning4j.
In class ExecuteWorkerMultiDataSetFlatMapAdapter, method call.
/**
 * Feeds the supplied MultiDataSets to the worker in minibatches and returns the
 * worker's result.
 *
 * <p>The data is batched to the configured per-worker batch size and, when the
 * configured prefetch count is positive, wrapped in an asynchronous prefetching
 * iterator. Processing stops as soon as the worker returns a non-null result for a
 * minibatch, or once the configured maximum number of minibatches (unbounded when
 * that setting is not positive) has been processed; in the latter case the worker's
 * final result is requested explicitly. When stats collection is enabled, timing
 * events are logged around each step and attached to the returned result.
 *
 * @param dataSetIterator the MultiDataSets available to this worker
 * @return a single-element collection holding the worker's result, or an empty
 *         collection when the iterator has no elements
 */
@Override
public Iterable<R> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
    WorkerConfiguration dataConfig = worker.getDataConfiguration();
    final boolean collectStats = dataConfig.isCollectTrainingStats();

    // The helper exists only when stats collection is enabled
    StatsCalculationHelper statsHelper = null;
    if (collectStats) {
        statsHelper = new StatsCalculationHelper();
        statsHelper.logMethodStartTime();
    }

    if (!dataSetIterator.hasNext()) {
        if (collectStats) {
            statsHelper.logReturnTime();
        }
        //Sometimes: no data
        return Collections.emptyList();
    }

    int batchSize = dataConfig.getBatchSizePerWorker();
    final int prefetchCount = dataConfig.getPrefetchNumBatches();
    MultiDataSetIterator batches = new IteratorMultiDataSetIterator(dataSetIterator, batchSize);
    if (prefetchCount > 0) {
        batches = new AsyncMultiDataSetIterator(batches, prefetchCount);
    }

    try {
        if (collectStats) {
            statsHelper.logInitialModelBefore();
        }
        ComputationGraph net = worker.getInitialModelGraph();
        if (collectStats) {
            statsHelper.logInitialModelAfter();
        }

        // A non-positive configured maximum means "no limit"
        int maxBatches = Integer.MAX_VALUE;
        if (dataConfig.getMaxBatchesPerWorker() > 0) {
            maxBatches = dataConfig.getMaxBatchesPerWorker();
        }

        // NOTE(review): the "is last" flag handed to the worker is derived from the iterator
        // only, so it is false when the loop stops because maxBatches was reached - confirm
        // that the final-result fallback below is the intended handling for that case.
        for (int processed = 0; batches.hasNext() && processed < maxBatches; processed++) {
            if (collectStats) {
                statsHelper.logNextDataSetBefore();
            }
            MultiDataSet batch = batches.next();

            if (!collectStats) {
                R earlyResult = worker.processMinibatch(batch, net, !batches.hasNext());
                if (earlyResult != null) {
                    //Terminate training immediately
                    return Collections.singletonList(earlyResult);
                }
                continue;
            }

            statsHelper.logNextDataSetAfter(batch.getFeatures(0).size(0));
            statsHelper.logProcessMinibatchBefore();
            Pair<R, SparkTrainingStats> outcome = worker.processMinibatchWithStats(batch, net, !batches.hasNext());
            statsHelper.logProcessMinibatchAfter();
            if (outcome != null) {
                //Terminate training immediately
                statsHelper.logReturnTime();
                SparkTrainingStats combinedStats = statsHelper.build(outcome.getSecond());
                outcome.getFirst().setStats(combinedStats);
                return Collections.singletonList(outcome.getFirst());
            }
        }

        //For some reason, we didn't return already. Normally this shouldn't happen
        if (!collectStats) {
            return Collections.singletonList(worker.getFinalResult(net));
        }
        statsHelper.logReturnTime();
        Pair<R, SparkTrainingStats> finalOutcome = worker.getFinalResultWithStats(net);
        finalOutcome.getFirst().setStats(statsHelper.build(finalOutcome.getSecond()));
        return Collections.singletonList(finalOutcome.getFirst());
    } finally {
        //Make sure we shut down the async thread properly...
        if (batches instanceof AsyncMultiDataSetIterator) {
            ((AsyncMultiDataSetIterator) batches).shutdown();
        }

        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
        }
    }
}
Aggregations