Use of org.deeplearning4j.datasets.iterator.IteratorDataSetIterator in project deeplearning4j by deeplearning4j.
Class ScoreFlatMapFunctionCGDataSetAdapter, method call.
/**
 * Scores every minibatch drawn from the supplied iterator with a ComputationGraph
 * that is rebuilt from the JSON configuration and loaded with a copy of the
 * broadcast parameters.
 *
 * @param dataSetIterator source of the DataSets to score
 * @return one (example count, score * example count) tuple per minibatch, or a
 *         single (0, 0.0) tuple when the iterator has no elements
 * @throws IllegalStateException if the broadcast parameter length differs from the
 *         parameter count of the rebuilt network
 * @throws Exception as declared by the overridden method
 */
@Override
public Iterable<Tuple2<Integer, Double>> call(Iterator<DataSet> dataSetIterator) throws Exception {
    // Nothing to score: report zero examples with a zero score
    if (!dataSetIterator.hasNext()) {
        Tuple2<Integer, Double> noData = new Tuple2<>(0, 0.0);
        return Collections.singletonList(noData);
    }

    // Re-batch the incoming DataSets where appropriate
    DataSetIterator batches = new IteratorDataSetIterator(dataSetIterator, minibatchSize);

    ComputationGraphConfiguration graphConf = ComputationGraphConfiguration.fromJson(json);
    ComputationGraph graph = new ComputationGraph(graphConf);
    graph.init();

    // params.value() is shared by all executors on a single machine; that is OK
    // because the score function does not change the parameters
    INDArray paramCopy = params.value().unsafeDuplication();
    if (paramCopy.length() != graph.numParams(false)) {
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
    }
    graph.setParams(paramCopy);

    List<Tuple2<Integer, Double>> results = new ArrayList<>();
    while (batches.hasNext()) {
        DataSet batch = batches.next();
        double batchScore = graph.score(batch, false);
        int exampleCount = batch.getFeatureMatrix().size(0);
        // Second element is score * example count, not the raw score
        results.add(new Tuple2<>(exampleCount, batchScore * exampleCount));
    }

    // Block until any queued operations of a GridExecutioner have been flushed
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        GridExecutioner gridExec = (GridExecutioner) Nd4j.getExecutioner();
        gridExec.flushQueueBlocking();
    }
    return results;
}
Use of org.deeplearning4j.datasets.iterator.IteratorDataSetIterator in project deeplearning4j by deeplearning4j.
Class ComputationGraphTestRNN, method checkMaskArrayClearance.
/**
 * Simple "does it throw an exception" style test: after each variant of fit(...)
 * the graph must report no input/label mask arrays and no layer may still hold a
 * mask array. Executed once with truncated BPTT and once with standard backprop.
 */
@Test
public void checkMaskArrayClearance() {
    boolean[] tbpttOptions = { true, false };
    for (boolean tbptt : tbpttOptions) {
        BackpropType backpropType = tbptt ? BackpropType.TruncatedBPTT : BackpropType.Standard;

        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .iterations(1)
                .seed(12345)
                .graphBuilder()
                .addInputs("in")
                .addLayer("out", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .activation(Activation.IDENTITY)
                        .nIn(1)
                        .nOut(1)
                        .build(), "in")
                .setOutputs("out")
                .backpropType(backpropType)
                .tBPTTForwardLength(8)
                .tBPTTBackwardLength(8)
                .build();

        ComputationGraph net = new ComputationGraph(conf);
        net.init();

        INDArray features = Nd4j.linspace(1, 10, 10).reshape(1, 1, 10);
        INDArray labels = Nd4j.linspace(2, 20, 10).reshape(1, 1, 10);
        INDArray featuresMask = Nd4j.ones(10);
        INDArray labelsMask = Nd4j.ones(10);
        MultiDataSet data = new MultiDataSet(
                new INDArray[] { features },
                new INDArray[] { labels },
                new INDArray[] { featuresMask },
                new INDArray[] { labelsMask });

        // fit(MultiDataSet)
        net.fit(data);
        assertAllMasksCleared(net);

        // fit(DataSet)
        DataSet ds = new DataSet(data.getFeatures(0), data.getLabels(0),
                data.getFeaturesMaskArray(0), data.getLabelsMaskArray(0));
        net.fit(ds);
        assertAllMasksCleared(net);

        // fit(features[], labels[], featureMasks[], labelMasks[])
        net.fit(data.getFeatures(), data.getLabels(), data.getFeaturesMaskArrays(), data.getLabelsMaskArrays());
        assertAllMasksCleared(net);

        // fit(MultiDataSetIterator)
        org.nd4j.linalg.dataset.api.MultiDataSet apiData = data;
        MultiDataSetIterator iter = new IteratorMultiDataSetIterator(Collections.singletonList(apiData).iterator(), 1);
        net.fit(iter);
        assertAllMasksCleared(net);

        // fit(DataSetIterator)
        DataSetIterator iter2 = new IteratorDataSetIterator(Collections.singletonList(ds).iterator(), 1);
        net.fit(iter2);
        assertAllMasksCleared(net);
    }
}

/** Asserts that neither the graph nor any of its layers still holds a mask array. */
private static void assertAllMasksCleared(ComputationGraph net) {
    assertNull(net.getInputMaskArrays());
    assertNull(net.getLabelMaskArrays());
    for (Layer layer : net.getLayers()) {
        assertNull(layer.getMaskArray());
    }
}
Use of org.deeplearning4j.datasets.iterator.IteratorDataSetIterator in project deeplearning4j by deeplearning4j.
Class ExecuteWorkerFlatMapAdapter, method call.
/**
 * Feeds the DataSets from the supplied iterator to {@code worker} in minibatches and
 * returns the worker's result wrapped in a single-element list.
 *
 * <p>Flow, as visible in this method:
 * <ul>
 *   <li>If the iterator is empty, the worker's "no data" result is returned.</li>
 *   <li>Otherwise the DataSets are re-batched to the configured batch size per worker and,
 *       when the configured prefetch count is positive, wrapped in an AsyncDataSetIterator.</li>
 *   <li>Minibatches are handed to the worker until the iterator is exhausted, the configured
 *       maximum number of batches is reached, or the worker returns a non-null result (which
 *       is returned immediately).</li>
 *   <li>If the loop ends without a result, the worker's final result is returned.</li>
 * </ul>
 * When training-stats collection is enabled, timing events are logged through a
 * StatsCalculationHelper and the built stats are attached to the returned result.
 *
 * @param dataSetIterator source of the DataSets to process
 * @return a single-element list holding the worker's result
 * @throws Exception as declared by the overridden method
 */
@Override
public Iterable<R> call(Iterator<DataSet> dataSetIterator) throws Exception {
WorkerConfiguration dataConfig = worker.getDataConfiguration();
final boolean isGraph = dataConfig.isGraphNetwork();
boolean stats = dataConfig.isCollectTrainingStats();
// The stats helper only exists when stats collection is enabled; every use below is guarded by 'stats'
StatsCalculationHelper s = (stats ? new StatsCalculationHelper() : null);
if (stats)
s.logMethodStartTime();
// No data to process: return the worker's "no data" result (with stats attached if enabled)
if (!dataSetIterator.hasNext()) {
if (stats) {
s.logReturnTime();
Pair<R, SparkTrainingStats> pair = worker.getFinalResultNoDataWithStats();
pair.getFirst().setStats(s.build(pair.getSecond()));
return Collections.singletonList(pair.getFirst());
} else {
return Collections.singletonList(worker.getFinalResultNoData());
}
}
int batchSize = dataConfig.getBatchSizePerWorker();
final int prefetchCount = dataConfig.getPrefetchNumBatches();
// Re-batch the incoming DataSets to the configured per-worker batch size
DataSetIterator batchedIterator = new IteratorDataSetIterator(dataSetIterator, batchSize);
// Optionally wrap in an async iterator; it is shut down in the finally block below
if (prefetchCount > 0) {
batchedIterator = new AsyncDataSetIterator(batchedIterator, prefetchCount);
}
try {
// Exactly one of net/graph is assigned, depending on isGraph; the other stays null
MultiLayerNetwork net = null;
ComputationGraph graph = null;
if (stats)
s.logInitialModelBefore();
if (isGraph)
graph = worker.getInitialModelGraph();
else
net = worker.getInitialModel();
if (stats)
s.logInitialModelAfter();
int miniBatchCount = 0;
// A non-positive configured maximum means "no limit"
int maxMinibatches = (dataConfig.getMaxBatchesPerWorker() > 0 ? dataConfig.getMaxBatchesPerWorker() : Integer.MAX_VALUE);
while (batchedIterator.hasNext() && miniBatchCount++ < maxMinibatches) {
if (stats)
s.logNextDataSetBefore();
DataSet next = batchedIterator.next();
if (stats)
s.logNextDataSetAfter(next.numExamples());
// The third argument is true when the iterator has no further batches
// (presumably an "is last minibatch" flag - confirm against the worker interface)
if (stats) {
s.logProcessMinibatchBefore();
Pair<R, SparkTrainingStats> result;
if (isGraph)
result = worker.processMinibatchWithStats(next, graph, !batchedIterator.hasNext());
else
result = worker.processMinibatchWithStats(next, net, !batchedIterator.hasNext());
s.logProcessMinibatchAfter();
if (result != null) {
//Terminate training immediately
s.logReturnTime();
SparkTrainingStats workerStats = result.getSecond();
SparkTrainingStats returnStats = s.build(workerStats);
result.getFirst().setStats(returnStats);
return Collections.singletonList(result.getFirst());
}
} else {
R result;
if (isGraph)
result = worker.processMinibatch(next, graph, !batchedIterator.hasNext());
else
result = worker.processMinibatch(next, net, !batchedIterator.hasNext());
if (result != null) {
//Terminate training immediately
return Collections.singletonList(result);
}
}
}
//For some reason, we didn't return already. Normally this shouldn't happen
if (stats) {
s.logReturnTime();
Pair<R, SparkTrainingStats> pair;
if (isGraph)
pair = worker.getFinalResultWithStats(graph);
else
pair = worker.getFinalResultWithStats(net);
pair.getFirst().setStats(s.build(pair.getSecond()));
return Collections.singletonList(pair.getFirst());
} else {
if (isGraph)
return Collections.singletonList(worker.getFinalResult(graph));
else
return Collections.singletonList(worker.getFinalResult(net));
}
} finally {
//Make sure we shut down the async thread properly...
if (batchedIterator instanceof AsyncDataSetIterator) {
((AsyncDataSetIterator) batchedIterator).shutdown();
}
// Runs on every exit path (return or exception): block until a GridExecutioner's queue is flushed
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
}
}
Use of org.deeplearning4j.datasets.iterator.IteratorDataSetIterator in project deeplearning4j by deeplearning4j.
Class ScoreFlatMapFunctionAdapter, method call.
/**
 * Scores every minibatch drawn from the supplied iterator with a MultiLayerNetwork
 * that is rebuilt from the JSON configuration and loaded with a copy of the
 * broadcast parameters.
 *
 * @param dataSetIterator source of the DataSets to score
 * @return one (example count, score * example count) tuple per minibatch, or a
 *         single (0, 0.0) tuple when the iterator has no elements
 * @throws IllegalStateException if the broadcast parameter length differs from the
 *         parameter count of the rebuilt network
 * @throws Exception as declared by the overridden method
 */
@Override
public Iterable<Tuple2<Integer, Double>> call(Iterator<DataSet> dataSetIterator) throws Exception {
    // Nothing to score: report zero examples with a zero score
    if (!dataSetIterator.hasNext()) {
        Tuple2<Integer, Double> noData = new Tuple2<>(0, 0.0);
        return Collections.singletonList(noData);
    }

    // Re-batch the incoming DataSets where appropriate
    DataSetIterator batches = new IteratorDataSetIterator(dataSetIterator, minibatchSize);

    MultiLayerConfiguration layerConf = MultiLayerConfiguration.fromJson(json);
    MultiLayerNetwork model = new MultiLayerNetwork(layerConf);
    model.init();

    // The params.value() object will be shared by all executors on each machine;
    // that is OK because the score function does not modify the parameters
    INDArray paramCopy = params.value().unsafeDuplication();
    if (paramCopy.length() != model.numParams(false)) {
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
    }
    model.setParameters(paramCopy);

    List<Tuple2<Integer, Double>> results = new ArrayList<>();
    while (batches.hasNext()) {
        DataSet batch = batches.next();
        double batchScore = model.score(batch, false);
        int exampleCount = batch.getFeatureMatrix().size(0);
        // Second element is score * example count, not the raw score
        results.add(new Tuple2<>(exampleCount, batchScore * exampleCount));
    }

    // Block until any queued operations of a GridExecutioner have been flushed
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        GridExecutioner gridExec = (GridExecutioner) Nd4j.getExecutioner();
        gridExec.flushQueueBlocking();
    }
    return results;
}
Aggregations