use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.
the class RecordReaderMultiDataSetIteratorTest method testInputValidation.
/**
 * Checks that building a {@code RecordReaderMultiDataSetIterator} fails for three invalid
 * configurations: no readers registered, an input referring to a reader name that was never
 * added, and a reader registered without any inputs or outputs.
 *
 * <p>Reader setup is deliberately kept outside the {@code try} blocks, so that a failure to
 * locate or read {@code iris.txt} surfaces as a test error instead of being mistaken for the
 * expected validation exception.
 *
 * @throws Exception if the iris.txt test resource cannot be extracted or the reader cannot be initialized
 */
@Test
public void testInputValidation() throws Exception {
    //Test: no readers
    try {
        new RecordReaderMultiDataSetIterator.Builder(1).addInput("something").addOutput("something").build();
        fail("Should have thrown exception");
    } catch (Exception expected) {
        // Intentionally empty: an exception is the outcome under test.
        // NOTE(review): relies on fail(...) raising an Error (as JUnit's does), which this catch does not intercept - confirm.
    }

    //Test: reference to reader that doesn't exist
    RecordReader unknownNameReader = new CSVRecordReader(0, ",");
    unknownNameReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    try {
        new RecordReaderMultiDataSetIterator.Builder(1).addReader("iris", unknownNameReader).addInput("thisDoesntExist", 0, 3).addOutputOneHot("iris", 4, 3).build();
        fail("Should have thrown exception");
    } catch (Exception expected) {
        // Intentionally empty: an exception is the outcome under test.
    }

    //Test: no inputs or outputs
    RecordReader unusedReader = new CSVRecordReader(0, ",");
    unusedReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    try {
        new RecordReaderMultiDataSetIterator.Builder(1).addReader("iris", unusedReader).build();
        fail("Should have thrown exception");
    } catch (Exception expected) {
        // Intentionally empty: an exception is the outcome under test.
    }
}
use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.
the class RecordReaderMultiDataSetIteratorTest method testsBasic.
/**
 * Compares {@code RecordReaderMultiDataSetIterator} against the single-DataSet iterators on the
 * same data, first for plain CSV records (iris.txt) and then for CSV sequence files.
 *
 * <p>For every batch it asserts that the MultiDataSet holds exactly one feature array and one
 * label array, has no mask arrays, and that its arrays equal those of the corresponding DataSet.
 *
 * @throws Exception if a test resource cannot be extracted or a reader cannot be initialized
 */
@Test
public void testsBasic() throws Exception {
    //Load details from CSV files; single input/output -> compare to RecordReaderDataSetIterator
    RecordReader rr = new CSVRecordReader(0, ",");
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 10, 4, 3);

    RecordReader rr2 = new CSVRecordReader(0, ",");
    rr2.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10).addReader("reader", rr2).addInput("reader", 0, 3).addOutputOneHot("reader", 4, 3).build();

    while (rrdsi.hasNext()) {
        DataSet ds = rrdsi.next();
        INDArray fds = ds.getFeatureMatrix();
        INDArray lds = ds.getLabels();

        MultiDataSet mds = rrmdsi.next();
        assertEquals(1, mds.getFeatures().length);
        assertEquals(1, mds.getLabels().length);
        assertNull(mds.getFeaturesMaskArrays());
        assertNull(mds.getLabelsMaskArrays());

        INDArray fmds = mds.getFeatures(0);
        INDArray lmds = mds.getLabels(0);
        assertNotNull(fmds);
        assertNotNull(lmds);
        assertEquals(fds, fmds);
        assertEquals(lds, lmds);
    }
    // Both iterators must be exhausted at the same time
    assertFalse(rrmdsi.hasNext());

    //need to manually extract
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive();
    }

    //Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
    // Build the "%d" patterns from the file NAME only. Replacing every '0' in the whole absolute
    // path would also rewrite zeros in the (temp) directory and produce a pattern with several
    // "%d" placeholders pointing at a directory that does not exist.
    java.io.File featuresFile = new ClassPathResource("csvsequence_0.txt").getTempFileFromArchive().getAbsoluteFile();
    String featuresPath = new java.io.File(featuresFile.getParentFile(), featuresFile.getName().replace("0", "%d")).getAbsolutePath();
    java.io.File labelsFile = new ClassPathResource("csvsequencelabels_0.txt").getTempFileFromArchive().getAbsoluteFile();
    String labelsPath = new java.io.File(labelsFile.getParentFile(), labelsFile.getName().replace("0", "%d")).getAbsolutePath();

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);

    SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
    featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1).addSequenceReader("in", featureReader2).addSequenceReader("out", labelReader2).addInput("in").addOutputOneHot("out", 0, 4).build();

    while (iter.hasNext()) {
        DataSet ds = iter.next();
        INDArray fds = ds.getFeatureMatrix();
        INDArray lds = ds.getLabels();

        MultiDataSet mds = srrmdsi.next();
        assertEquals(1, mds.getFeatures().length);
        assertEquals(1, mds.getLabels().length);
        assertNull(mds.getFeaturesMaskArrays());
        assertNull(mds.getLabelsMaskArrays());

        INDArray fmds = mds.getFeatures(0);
        INDArray lmds = mds.getLabels(0);
        assertNotNull(fmds);
        assertNotNull(lmds);
        assertEquals(fds, fmds);
        assertEquals(lds, lmds);
    }
    // Both iterators must be exhausted at the same time
    assertFalse(srrmdsi.hasNext());
}
use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.
the class TestComputationGraphNetwork method testIrisFitMultiDataSetIterator.
/**
 * Fits a small two-layer ComputationGraph on iris.txt through a MultiDataSetIterator, first by
 * passing the whole iterator to {@code fit} and then, after resetting the reader, by passing
 * each MultiDataSet individually. The test contains no assertions; it passes when no exception
 * is thrown.
 *
 * @throws Exception if the iris.txt resource cannot be extracted or the reader cannot be initialized
 */
@Test
public void testIrisFitMultiDataSetIterator() throws Exception {
    RecordReader irisReader = new CSVRecordReader(0, ",");
    irisReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));

    MultiDataSetIterator batches = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("iris", irisReader)
            .addInput("iris", 0, 3)
            .addOutputOneHot("iris", 4, 3)
            .build();

    // Graph: "in" -> dense (4 in, 2 out) -> out (2 in, 3 out, MCXENT loss)
    ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(0.1)
            .graphBuilder()
            .addInputs("in")
            .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in")
            .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).build(), "dense")
            .setOutputs("out")
            .pretrain(false)
            .backprop(true)
            .build();

    ComputationGraph graph = new ComputationGraph(graphConf);
    graph.init();

    // First pass: hand the whole iterator to fit(...)
    graph.fit(batches);

    // Second pass: rewind the reader, rebuild the iterator and fit one MultiDataSet at a time
    irisReader.reset();
    batches = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("iris", irisReader)
            .addInput("iris", 0, 3)
            .addOutputOneHot("iris", 4, 3)
            .build();
    while (batches.hasNext()) {
        graph.fit(batches.next());
    }
}
use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.
the class ComputationGraph method fit.
/**
 * Fit the ComputationGraph on one set of input and label arrays, with optional mask arrays.
 * <p>
 * Depending on the configuration this runs pretraining, backprop (standard or truncated BPTT),
 * or both. Layer mask arrays are cleared again afterwards if any were supplied.
 *
 * @param inputs            The network inputs (features)
 * @param labels            The network labels
 * @param featureMaskArrays Mask arrays for inputs/features. Typically used for RNN training. May be null.
 * @param labelMaskArrays   Mask arrays for the labels/outputs. Typically used for RNN training. May be null.
 */
public void fit(INDArray[] inputs, INDArray[] labels, INDArray[] featureMaskArrays, INDArray[] labelMaskArrays) {
    if (flattenedGradients == null) {
        initGradientsView();
    }

    setInputs(inputs);
    setLabels(labels);
    setLayerMaskArrays(featureMaskArrays, labelMaskArrays);
    update(TaskUtils.buildTask(inputs, labels));

    if (configuration.isPretrain()) {
        // Wrap the arrays in a one-element iterator, since pretrain(...) is given an iterator
        org.nd4j.linalg.dataset.MultiDataSet singleBatch =
                new org.nd4j.linalg.dataset.MultiDataSet(inputs, labels, featureMaskArrays, labelMaskArrays);
        MultiDataSetIterator singleBatchIterator = new SingletonMultiDataSetIterator(singleBatch);
        pretrain(singleBatchIterator);
    }

    if (configuration.isBackprop()) {
        if (configuration.getBackpropType() != BackpropType.TruncatedBPTT) {
            // The solver is created on first use and reused afterwards
            if (solver == null) {
                solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
            }
            solver.optimize();
        } else {
            doTruncatedBPTT(inputs, labels, featureMaskArrays, labelMaskArrays);
        }
    }

    boolean anyMaskProvided = featureMaskArrays != null || labelMaskArrays != null;
    if (anyMaskProvided) {
        clearLayerMaskArrays();
    }
}
use of org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator in project deeplearning4j by deeplearning4j.
the class ParallelWrapper method fit.
/**
 * Trains the wrapped model on the MultiDataSets supplied by the given iterator.
 * <p>
 * Each MultiDataSet is handed to the next trainer in turn. Once all trainers have received one
 * (or the iterator is exhausted), the method waits for the dispatched trainers and, every
 * {@code averagingFrequency} full rounds, calls {@code averageUpdatersState} before the next
 * round begins. The loop ends when the iterator has no more data or {@code stopFit} is set.
 *
 * @param source iterator providing the training data; it is reset before use and must not be null
 * @throws ND4JIllegalStateException if the iterator returns a null MultiDataSet
 * @throws RuntimeException if waiting on a trainer fails, or if averaging is reached and the model is not a ComputationGraph
 */
public synchronized void fit(@NonNull MultiDataSetIterator source) {
    stopFit.set(false);

    if (zoo != null) {
        // Trainers already exist: switch them to the MultiDataSet queue
        for (int idx = 0; idx < workers; idx++) {
            zoo[idx].useMDS = true;
        }
    } else {
        zoo = new Trainer[workers];
        for (int idx = 0; idx < workers; idx++) {
            // we pass true here, to tell Trainer to use MultiDataSet queue for training
            Trainer trainer = new Trainer(idx, model, Nd4j.getAffinityManager().getDeviceForCurrentThread(), true);
            zoo[idx] = trainer;
            trainer.setUncaughtExceptionHandler(handler);
            trainer.start();
        }
    }

    source.reset();

    // Use the async wrapper only when prefetching is enabled and the source supports it
    MultiDataSetIterator iterator = (prefetchSize > 0 && source.asyncSupported())
            ? new AsyncMultiDataSetIterator(source, prefetchSize)
            : source;

    // Number of MultiDataSets dispatched in the current round
    AtomicInteger dispatched = new AtomicInteger(0);

    while (iterator.hasNext() && !stopFit.get()) {
        MultiDataSet dataSet = iterator.next();
        if (dataSet == null) {
            throw new ND4JIllegalStateException("You can't have NULL as MultiDataSet");
        }

        // Dispatch to the next trainer in this round
        int pos = dispatched.getAndIncrement();
        zoo[pos].feedMultiDataSet(dataSet);

        // Keep dispatching until every worker has data or the data runs out
        boolean roundComplete = pos + 1 == workers || !iterator.hasNext();
        if (!roundComplete) {
            continue;
        }

        iterationsCounter.incrementAndGet();

        // Wait on every trainer that received data in this round
        for (int idx = 0; idx < workers && idx < dispatched.get(); idx++) {
            try {
                zoo[idx].waitTillRunning();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        Nd4j.getMemoryManager().invokeGcOccasionally();

        // Average only after a full round (all workers fed) that lands on the averaging frequency
        if (iterationsCounter.get() % averagingFrequency == 0 && pos + 1 == workers) {
            double score = getScore(dispatched);

            // averaging updaters state
            if (!(model instanceof ComputationGraph)) {
                throw new RuntimeException("MultiDataSet must only be used with ComputationGraph model");
            }
            averageUpdatersState(dispatched, score);

            if (legacyAveraging && Nd4j.getAffinityManager().getNumberOfDevices() > 1) {
                for (int idx = 0; idx < workers; idx++) {
                    zoo[idx].updateModel(model);
                }
            }
        }

        dispatched.set(0);
    }

    // sanity checks, or the dataset may never average
    if (!wasAveraged) {
        log.warn("Parameters were never averaged on current fit(). Ratios of batch size, num workers, and averaging frequency may be responsible.");
    }

    log.debug("Iterations passed: {}", iterationsCounter.get());
}
Aggregations