Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.
The class ScoreFlatMapFunctionCGMultiDataSetAdapter, method call().
/**
 * Scores one partition of MultiDataSets against a ComputationGraph built from the broadcast
 * configuration ({@code json}) and parameters ({@code params}).
 *
 * @param dataSetIterator the partition's MultiDataSets; may be empty
 * @return one {@code (exampleCount, score * exampleCount)} tuple per minibatch of up to
 *         {@code minibatchSize} examples, or a single {@code (0, 0.0)} tuple for an empty partition.
 *         Presumably the caller sums these to form an example-weighted average score.
 */
@Override
public Iterable<Tuple2<Integer, Double>> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
    // Empty partition: contribute a zero-count, zero-score entry
    if (!dataSetIterator.hasNext()) {
        return Collections.singletonList(new Tuple2<>(0, 0.0));
    }

    // Re-batch the raw partition into minibatches of the configured size
    MultiDataSetIterator batches = new IteratorMultiDataSetIterator(dataSetIterator, minibatchSize);

    ComputationGraph graph = new ComputationGraph(ComputationGraphConfiguration.fromJson(json));
    graph.init();

    // params.value() may be shared by all executors on one machine; the parameters are only read
    // here, but work on a duplicate to be safe
    INDArray broadcastParams = params.value().unsafeDuplication();
    boolean sizeMismatch = broadcastParams.length() != graph.numParams(false);
    if (sizeMismatch) {
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
    }
    graph.setParams(broadcastParams);

    List<Tuple2<Integer, Double>> scores = new ArrayList<>();
    while (batches.hasNext()) {
        MultiDataSet batch = batches.next();
        double batchScore = graph.score(batch, false);
        int exampleCount = batch.getFeatures(0).size(0);
        // Weight the score by batch size so batches of different sizes can be combined
        scores.add(new Tuple2<>(exampleCount, batchScore * exampleCount));
    }

    // Grid executioners queue ops; flush them before handing results back
    Object executioner = Nd4j.getExecutioner();
    if (executioner instanceof GridExecutioner) {
        ((GridExecutioner) executioner).flushQueueBlocking();
    }
    return scores;
}
Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.
The class ExecuteWorkerMultiDataSetFlatMapAdapter, method call().
/**
 * Trains the worker's initial ComputationGraph on the MultiDataSets of one partition.
 * <p>
 * The raw partition iterator is re-batched to {@code getBatchSizePerWorker()} examples per minibatch.
 * If {@code getPrefetchNumBatches()} is positive, it is also wrapped in an AsyncMultiDataSetIterator.
 * Minibatches are then passed to the worker one at a time. Processing stops as soon as the worker
 * returns a non-null result. Otherwise the worker's final result is returned once the data, or the
 * {@code getMaxBatchesPerWorker()} limit (when positive), is exhausted.
 *
 * @param dataSetIterator the partition's MultiDataSets; may be empty
 * @return an empty list for an empty partition; otherwise a single-element list holding the worker's
 *         result, with training stats attached when stats collection is enabled
 */
@Override
public Iterable<R> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
WorkerConfiguration dataConfig = worker.getDataConfiguration();
boolean stats = dataConfig.isCollectTrainingStats();
// Timing helper; only created (and used) when stats collection is enabled
StatsCalculationHelper s = (stats ? new StatsCalculationHelper() : null);
if (stats)
s.logMethodStartTime();
if (!dataSetIterator.hasNext()) {
if (stats)
s.logReturnTime();
//Sometimes: no data
return Collections.emptyList();
}
int batchSize = dataConfig.getBatchSizePerWorker();
final int prefetchCount = dataConfig.getPrefetchNumBatches();
MultiDataSetIterator batchedIterator = new IteratorMultiDataSetIterator(dataSetIterator, batchSize);
if (prefetchCount > 0) {
// Prefetch batches asynchronously; the finally block below shuts this iterator down
batchedIterator = new AsyncMultiDataSetIterator(batchedIterator, prefetchCount);
}
try {
if (stats)
s.logInitialModelBefore();
ComputationGraph net = worker.getInitialModelGraph();
if (stats)
s.logInitialModelAfter();
int miniBatchCount = 0;
// A non-positive configured limit means "no limit"
int maxMinibatches = (dataConfig.getMaxBatchesPerWorker() > 0 ? dataConfig.getMaxBatchesPerWorker() : Integer.MAX_VALUE);
// Post-increment: the body runs at most maxMinibatches times
while (batchedIterator.hasNext() && miniBatchCount++ < maxMinibatches) {
if (stats)
s.logNextDataSetBefore();
MultiDataSet next = batchedIterator.next();
if (stats)
s.logNextDataSetAfter(next.getFeatures(0).size(0));
// The third argument flags the last minibatch of the partition.
// NOTE(review): this flag does not account for maxMinibatches. When the limit stops the loop,
// the final batch is not flagged and control falls through to the final-result path below.
if (stats) {
s.logProcessMinibatchBefore();
Pair<R, SparkTrainingStats> result = worker.processMinibatchWithStats(next, net, !batchedIterator.hasNext());
s.logProcessMinibatchAfter();
if (result != null) {
//Terminate training immediately
s.logReturnTime();
SparkTrainingStats workerStats = result.getSecond();
// Merge the worker's stats with this function's timing stats
SparkTrainingStats returnStats = s.build(workerStats);
result.getFirst().setStats(returnStats);
return Collections.singletonList(result.getFirst());
}
} else {
R result = worker.processMinibatch(next, net, !batchedIterator.hasNext());
if (result != null) {
//Terminate training immediately
return Collections.singletonList(result);
}
}
}
//For some reason, we didn't return already. Normally this shouldn't happen
// (reachable e.g. when the maxMinibatches limit ends the loop - see note above)
if (stats) {
s.logReturnTime();
Pair<R, SparkTrainingStats> pair = worker.getFinalResultWithStats(net);
pair.getFirst().setStats(s.build(pair.getSecond()));
return Collections.singletonList(pair.getFirst());
} else {
return Collections.singletonList(worker.getFinalResult(net));
}
} finally {
//Make sure we shut down the async thread properly...
if (batchedIterator instanceof AsyncMultiDataSetIterator) {
((AsyncMultiDataSetIterator) batchedIterator).shutdown();
}
// Grid executioners queue ops; flush them before leaving the partition
if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
}
}
Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.
The class BatchAndExportMultiDataSetsFunction, method processList().
/**
 * Merges and exports at most one minibatch taken from the front of {@code tempList}.
 * <p>
 * Nothing is exported if {@code tempList} is empty, or if it holds fewer than {@code minibatchSize}
 * examples and this is not the final export. Otherwise, examples are removed from the front of the list
 * until exactly {@code minibatchSize} have been collected (or the list runs out, on a final export).
 * They are then merged and exported as one file. A MultiDataSet that does not fit into the remaining
 * space is split into single examples, which are put back at the front of the list in their original order.
 * Any leftover examples stay in {@code tempList} for the next call.
 *
 * @param tempList     buffered MultiDataSets; mutated (consumed from the front)
 * @param partitionIdx partition index, passed through to {@code export}
 * @param countBefore  export counter value before this call
 * @param finalExport  if true, export whatever is buffered even if fewer than {@code minibatchSize} examples
 * @return the updated export counter and the list of paths exported by this call (empty or one entry)
 */
private Pair<Integer, List<String>> processList(LinkedList<MultiDataSet> tempList, int partitionIdx, int countBefore, boolean finalExport) throws Exception {
    // Count the examples currently buffered
    int numExamples = 0;
    for (MultiDataSet ds : tempList) {
        numExamples += ds.getFeatures(0).size(0);
    }
    if (tempList.isEmpty() || (numExamples < minibatchSize && !finalExport)) {
        // Not enough data yet (and not the final export): no op
        return new Pair<>(countBefore, Collections.<String>emptyList());
    }
    List<String> exportPaths = new ArrayList<>();
    int countAfter = countBefore;
    // Batch the required number of examples together
    int countSoFar = 0;
    List<MultiDataSet> tempToMerge = new ArrayList<>();
    while (!tempList.isEmpty() && countSoFar != minibatchSize) {
        MultiDataSet next = tempList.removeFirst();
        int nextSize = next.getFeatures(0).size(0);
        if (countSoFar + nextSize <= minibatchSize) {
            // The entire MultiDataSet fits
            tempToMerge.add(next);
            countSoFar += nextSize;
        } else {
            // Too large for the remaining space: split into single examples and put them back at the front
            // of the queue in their ORIGINAL order. (Calling addFirst once per example in forward order, as
            // before, reversed them.)
            tempList.addAll(0, next.asList());
        }
    }
    // tempToMerge now holds exactly minibatchSize examples (or fewer, only on a final export)
    MultiDataSet toExport = org.nd4j.linalg.dataset.MultiDataSet.merge(tempToMerge);
    exportPaths.add(export(toExport, partitionIdx, countAfter++));
    return new Pair<>(countAfter, exportPaths);
}
Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.
The class TestSparkComputationGraph, method testBasic().
/**
 * Smoke test: fits a small SparkComputationGraph on iris. It first uses an RDD of MultiDataSets built
 * with RecordReaderMultiDataSetIterator, then an RDD of plain DataSets from IrisDataSetIterator.
 */
@Test
public void testBasic() throws Exception {
    JavaSparkContext context = this.sc;

    // Iris as single-example MultiDataSets: inputs from columns 0-3, one-hot label from column 4
    RecordReader irisReader = new CSVRecordReader(0, ",");
    irisReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    MultiDataSetIterator mdsIter = new RecordReaderMultiDataSetIterator.Builder(1)
            .addReader("iris", irisReader)
            .addInput("iris", 0, 3)
            .addOutputOneHot("iris", 4, 3)
            .build();
    List<MultiDataSet> multiDataSets = new ArrayList<>(150);
    while (mdsIter.hasNext()) {
        multiDataSets.add(mdsIter.next());
    }

    // in (4) -> dense (2) -> softmax-style MCXENT output (3)
    ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(0.1)
            .graphBuilder()
            .addInputs("in")
            .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in")
            .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).build(), "dense")
            .setOutputs("out")
            .pretrain(false)
            .backprop(true)
            .build();
    ComputationGraph graph = new ComputationGraph(graphConf);
    graph.init();

    TrainingMaster trainingMaster = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
    SparkComputationGraph sparkGraph = new SparkComputationGraph(context, graph, trainingMaster);
    sparkGraph.setListeners(Collections.singleton((IterationListener) new ScoreIterationListener(1)));

    JavaRDD<MultiDataSet> multiRdd = context.parallelize(multiDataSets);
    sparkGraph.fitMultiDataSet(multiRdd);

    // Same network, now fitted from plain DataSet objects
    DataSetIterator irisIter = new IrisDataSetIterator(1, 150);
    List<DataSet> dataSets = new ArrayList<>();
    while (irisIter.hasNext()) {
        dataSets.add(irisIter.next());
    }
    JavaRDD<DataSet> dataSetRdd = context.parallelize(dataSets);
    sparkGraph.fit(dataSetRdd);
}
Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.
The class RecordReaderMultiDataSetIteratorTest, method testSplittingCSVSequence().
/**
 * Checks that RecordReaderMultiDataSetIterator can split one sequence reader's columns into two inputs
 * (columns 0-1 and column 2), matching the corresponding slices of SequenceRecordReaderDataSetIterator.
 */
@Test
public void testSplittingCSVSequence() throws Exception {
    //need to manually extract
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive();
    }
    // Build NumberedFileInputSplit patterns. Only the file index is replaced: the previous
    // replaceAll("0", "%d") also rewrote any '0' in the temp directory path.
    ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
    String featuresPath = numberedPathPattern(resource.getTempFileFromArchive().getAbsolutePath());
    resource = new ClassPathResource("csvsequencelabels_0.txt");
    String labelsPath = numberedPathPattern(resource.getTempFileFromArchive().getAbsolutePath());

    // Reference iterator
    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);

    // Iterator under test: features split into two inputs
    SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
    featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
            .addSequenceReader("seq1", featureReader2)
            .addSequenceReader("seq2", labelReader2)
            .addInput("seq1", 0, 1)
            .addInput("seq1", 2, 2)
            .addOutputOneHot("seq2", 0, 4)
            .build();

    while (iter.hasNext()) {
        DataSet ds = iter.next();
        INDArray fds = ds.getFeatureMatrix();
        INDArray lds = ds.getLabels();
        MultiDataSet mds = srrmdsi.next();
        assertEquals(2, mds.getFeatures().length);
        assertEquals(1, mds.getLabels().length);
        assertNull(mds.getFeaturesMaskArrays());
        assertNull(mds.getLabelsMaskArrays());
        INDArray[] fmds = mds.getFeatures();
        INDArray[] lmds = mds.getLabels();
        assertNotNull(fmds);
        assertNotNull(lmds);
        for (int i = 0; i < fmds.length; i++) assertNotNull(fmds[i]);
        for (int i = 0; i < lmds.length; i++) assertNotNull(lmds[i]);
        // Expected inputs: columns 0-1 and column 2 of the reference features
        INDArray expIn1 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 1, true), NDArrayIndex.all());
        INDArray expIn2 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 2, true), NDArrayIndex.all());
        assertEquals(expIn1, fmds[0]);
        assertEquals(expIn2, fmds[1]);
        assertEquals(lds, lmds[0]);
    }
    assertFalse(srrmdsi.hasNext());
}

/**
 * Turns the absolute path of an extracted "..._0.txt" file into a String.format pattern by replacing
 * only the LAST '0' (the file index) with "%d". Zeros elsewhere in the path are left untouched.
 *
 * @throws IllegalStateException if the path contains no '0'
 */
private static String numberedPathPattern(String absolutePath) {
    int idx = absolutePath.lastIndexOf('0');
    if (idx < 0) {
        throw new IllegalStateException("Expected a '0' file index in path: " + absolutePath);
    }
    return absolutePath.substring(0, idx) + "%d" + absolutePath.substring(idx + 1);
}
Aggregations