Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.
From the class ComputationGraph, method fit.
/**
 * Fit the ComputationGraph using a MultiDataSetIterator
 */
public void fit(MultiDataSetIterator multi) {
    if (flattenedGradients == null)
        initGradientsView();
    MultiDataSetIterator multiDataSetIterator;
    if (multi.asyncSupported()) {
        multiDataSetIterator = new AsyncMultiDataSetIterator(multi, 2);
    } else
        multiDataSetIterator = multi;
    if (configuration.isPretrain()) {
        pretrain(multiDataSetIterator);
    }
    if (configuration.isBackprop()) {
        while (multiDataSetIterator.hasNext()) {
            MultiDataSet next = multiDataSetIterator.next();
            if (next.getFeatures() == null || next.getLabels() == null)
                break;
            if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
                doTruncatedBPTT(next.getFeatures(), next.getLabels(), next.getFeaturesMaskArrays(),
                                next.getLabelsMaskArrays());
            } else {
                boolean hasMaskArrays = next.hasMaskArrays();
                if (hasMaskArrays) {
                    setLayerMaskArrays(next.getFeaturesMaskArrays(), next.getLabelsMaskArrays());
                }
                setInputs(next.getFeatures());
                setLabels(next.getLabels());
                if (solver == null) {
                    solver = new Solver.Builder().configure(defaultConfiguration).listeners(listeners)
                                    .model(this).build();
                }
                solver.optimize();
                if (hasMaskArrays) {
                    clearLayerMaskArrays();
                }
            }
            Nd4j.getMemoryManager().invokeGcOccasionally();
        }
    }
}
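As a point of reference, a minimal usage sketch for this method is shown below. The graph and iterator names, the epoch loop, and the assumption that the iterator supports reset() are illustrative rather than taken from the source above; any MultiDataSetIterator (for example a RecordReaderMultiDataSetIterator, as in the tests further down) can be passed in.

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

public class FitSketch {
    // Hypothetical helper: 'graph' is an already-initialized ComputationGraph,
    // 'iterator' is any MultiDataSetIterator that supports reset().
    public static void fitGraph(ComputationGraph graph, MultiDataSetIterator iterator, int numEpochs) {
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            iterator.reset();    // rewind before each pass over the data
            graph.fit(iterator); // fit(MultiDataSetIterator), as shown above, consumes the whole iterator
        }
    }
}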
Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.
From the class ComputationGraphUtil, method toMultiDataSet.
/** Convert a DataSet to the equivalent MultiDataSet */
public static MultiDataSet toMultiDataSet(DataSet dataSet) {
    INDArray f = dataSet.getFeatures();
    INDArray l = dataSet.getLabels();
    INDArray fMask = dataSet.getFeaturesMaskArray();
    INDArray lMask = dataSet.getLabelsMaskArray();
    INDArray[] fNew = f == null ? null : new INDArray[] { f };
    INDArray[] lNew = l == null ? null : new INDArray[] { l };
    INDArray[] fMaskNew = (fMask != null ? new INDArray[] { fMask } : null);
    INDArray[] lMaskNew = (lMask != null ? new INDArray[] { lMask } : null);
    return new org.nd4j.linalg.dataset.MultiDataSet(fNew, lNew, fMaskNew, lMaskNew);
}
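A brief usage sketch of this utility follows. The dummy arrays and the import path for ComputationGraphUtil (assumed here to be org.deeplearning4j.nn.graph.util) are illustrative assumptions.

import org.deeplearning4j.nn.graph.util.ComputationGraphUtil; // assumed package
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;

public class ToMultiDataSetSketch {
    public static void main(String[] args) {
        INDArray features = Nd4j.rand(3, 4); // 3 examples, 4 features (dummy data)
        INDArray labels = Nd4j.rand(3, 2);   // 3 examples, 2 outputs (dummy data)
        DataSet ds = new DataSet(features, labels);

        MultiDataSet mds = ComputationGraphUtil.toMultiDataSet(ds);
        // The single feature/label arrays become element 0 of the MultiDataSet's arrays:
        System.out.println(mds.getFeatures(0).equals(features)); // true
        System.out.println(mds.getLabels(0).equals(labels));     // true
    }
}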
Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.
From the class RecordReaderMultiDataSetIterator, method nextMultiDataSet.
private MultiDataSet nextMultiDataSet(Map<String, List<List<Writable>>> nextRRVals,
                Map<String, List<List<List<Writable>>>> nextSeqRRVals, List<RecordMetaDataComposableMap> nextMetas) {
    int minExamples = Integer.MAX_VALUE;
    for (List<List<Writable>> exampleData : nextRRVals.values()) {
        minExamples = Math.min(minExamples, exampleData.size());
    }
    for (List<List<List<Writable>>> exampleData : nextSeqRRVals.values()) {
        minExamples = Math.min(minExamples, exampleData.size());
    }
    if (minExamples == Integer.MAX_VALUE)
        //Should never happen
        throw new RuntimeException("Error occurred during data set generation: no readers?");
    //In order to align data at the end (for each example individually), we need to know the length of the
    // longest time series for each example
    int[] longestSequence = null;
    if (alignmentMode == AlignmentMode.ALIGN_END) {
        longestSequence = new int[minExamples];
        for (Map.Entry<String, List<List<List<Writable>>>> entry : nextSeqRRVals.entrySet()) {
            List<List<List<Writable>>> list = entry.getValue();
            for (int i = 0; i < list.size() && i < minExamples; i++) {
                longestSequence[i] = Math.max(longestSequence[i], list.get(i).size());
            }
        }
    }
    //Second: create the input arrays
    //To do this, we need to know longest time series length, so we can do padding
    int longestTS = -1;
    if (alignmentMode != AlignmentMode.EQUAL_LENGTH) {
        for (Map.Entry<String, List<List<List<Writable>>>> entry : nextSeqRRVals.entrySet()) {
            List<List<List<Writable>>> list = entry.getValue();
            for (List<List<Writable>> c : list) {
                longestTS = Math.max(longestTS, c.size());
            }
        }
    }
    INDArray[] inputArrs = new INDArray[inputs.size()];
    INDArray[] inputArrMasks = new INDArray[inputs.size()];
    boolean inputMasks = false;
    int i = 0;
    for (SubsetDetails d : inputs) {
        if (nextRRVals.containsKey(d.readerName)) {
            //Standard reader
            List<List<Writable>> list = nextRRVals.get(d.readerName);
            inputArrs[i] = convertWritables(list, minExamples, d);
        } else {
            //Sequence reader
            List<List<List<Writable>>> list = nextSeqRRVals.get(d.readerName);
            Pair<INDArray, INDArray> p = convertWritablesSequence(list, minExamples, longestTS, d, longestSequence);
            inputArrs[i] = p.getFirst();
            inputArrMasks[i] = p.getSecond();
            if (inputArrMasks[i] != null)
                inputMasks = true;
        }
        i++;
    }
    if (!inputMasks)
        inputArrMasks = null;
    //Third: create the outputs
    INDArray[] outputArrs = new INDArray[outputs.size()];
    INDArray[] outputArrMasks = new INDArray[outputs.size()];
    boolean outputMasks = false;
    i = 0;
    for (SubsetDetails d : outputs) {
        if (nextRRVals.containsKey(d.readerName)) {
            //Standard reader
            List<List<Writable>> list = nextRRVals.get(d.readerName);
            outputArrs[i] = convertWritables(list, minExamples, d);
        } else {
            //Sequence reader
            List<List<List<Writable>>> list = nextSeqRRVals.get(d.readerName);
            Pair<INDArray, INDArray> p = convertWritablesSequence(list, minExamples, longestTS, d, longestSequence);
            outputArrs[i] = p.getFirst();
            outputArrMasks[i] = p.getSecond();
            if (outputArrMasks[i] != null)
                outputMasks = true;
        }
        i++;
    }
    if (!outputMasks)
        outputArrMasks = null;
    MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(inputArrs, outputArrs, inputArrMasks, outputArrMasks);
    if (collectMetaData) {
        mds.setExampleMetaData(nextMetas);
    }
    if (preProcessor != null)
        preProcessor.preProcess(mds);
    return mds;
}
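For orientation, the sketch below (with dummy arrays, not taken from the source) shows the structure that nextMultiDataSet ultimately assembles: parallel arrays of input and output INDArrays, plus optional per-array mask arrays for padded sequences.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;

public class MultiDataSetShapeSketch {
    public static void main(String[] args) {
        INDArray input0 = Nd4j.rand(4, 3);                // e.g. from a standard record reader
        INDArray input1 = Nd4j.rand(new int[] {4, 3, 5}); // e.g. a padded sequence: [examples, features, time]
        INDArray input1Mask = Nd4j.ones(4, 5);            // 1 = real time step, 0 = padding
        INDArray output0 = Nd4j.rand(4, 2);

        MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(
                        new INDArray[] {input0, input1},   // inputs
                        new INDArray[] {output0},          // outputs
                        new INDArray[] {null, input1Mask}, // input masks (null where not needed)
                        null);                             // no output masks
        System.out.println(mds.numFeatureArrays() + " inputs, " + mds.numLabelsArrays() + " outputs");
    }
}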
Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.
From the class RecordReaderMultiDataSetIteratorTest, method testImagesRRDMSI.
@Test
public void testImagesRRDMSI() throws Exception {
    File parentDir = Files.createTempDir();
    parentDir.deleteOnExit();
    String str1 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/");
    String str2 = FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/");
    File f1 = new File(str1);
    File f2 = new File(str2);
    f1.mkdirs();
    f2.mkdirs();
    writeStreamToFile(new File(FilenameUtils.concat(f1.getPath(), "Zico_0001.jpg")),
                    new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
    writeStreamToFile(new File(FilenameUtils.concat(f2.getPath(), "Ziwang_Xu_0001.jpg")),
                    new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());
    int outputNum = 2;
    Random r = new Random(12345);
    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
    ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1.initialize(new FileSplit(parentDir));
    rr1s.initialize(new FileSplit(parentDir));
    MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addReader("rr1", rr1)
                    .addReader("rr1s", rr1s)
                    .addInput("rr1", 0, 0)
                    .addInput("rr1s", 0, 0)
                    .addOutputOneHot("rr1s", 1, outputNum)
                    .build();
    //Now, do the same thing with ImageRecordReader, and check we get the same results:
    ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker);
    ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker);
    rr1_b.initialize(new FileSplit(parentDir));
    rr1s_b.initialize(new FileSplit(parentDir));
    DataSetIterator dsi1 = new RecordReaderDataSetIterator(rr1_b, 1, 1, 2);
    DataSetIterator dsi2 = new RecordReaderDataSetIterator(rr1s_b, 1, 1, 2);
    for (int i = 0; i < 2; i++) {
        MultiDataSet mds = trainDataIterator.next();
        DataSet d1 = dsi1.next();
        DataSet d2 = dsi2.next();
        assertEquals(d1.getFeatureMatrix(), mds.getFeatures(0));
        assertEquals(d2.getFeatureMatrix(), mds.getFeatures(1));
        assertEquals(d1.getLabels(), mds.getLabels(0));
    }
}
Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j by deeplearning4j.
From the class RecordReaderMultiDataSetIteratorTest, method testSplittingCSVMeta.
@Test
public void testSplittingCSVMeta() throws Exception {
    //Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays
    //Inputs: columns 0 and 1-2
    //Outputs: columns 3, and 4->OneHot
    RecordReader rr2 = new CSVRecordReader(0, ",");
    rr2.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10)
                    .addReader("reader", rr2)
                    .addInput("reader", 0, 0)
                    .addInput("reader", 1, 2)
                    .addOutput("reader", 3, 3)
                    .addOutputOneHot("reader", 4, 3)
                    .build();
    rrmdsi.setCollectMetaData(true);
    int count = 0;
    while (rrmdsi.hasNext()) {
        MultiDataSet mds = rrmdsi.next();
        MultiDataSet fromMeta = rrmdsi.loadFromMetaData(mds.getExampleMetaData(RecordMetaData.class));
        assertEquals(mds, fromMeta);
        count++;
    }
    assertEquals(150 / 10, count);
}