Use of org.datavec.api.records.metadata.RecordMetaDataComposableMap in project deeplearning4j by deeplearning4j.
From the class RecordReaderMultiDataSetIterator, the method loadFromMetaData:
/**
 * Load multiple examples to a MultiDataSet, using the provided RecordMetaData instances.
 *
 * @param list List of RecordMetaData instances to load from. Should have been produced by the record readers
 *             provided to this RecordReaderMultiDataSetIterator
 * @return MultiDataSet with the specified examples
 * @throws IOException If an error occurs during loading of the data
 */
public MultiDataSet loadFromMetaData(List<RecordMetaData> list) throws IOException {
    //First: load the next values from the RR / SeqRRs
    Map<String, List<List<Writable>>> nextRRVals = new HashMap<>();
    Map<String, List<List<List<Writable>>>> nextSeqRRVals = new HashMap<>();
    List<RecordMetaDataComposableMap> nextMetas =
                    (collectMetaData ? new ArrayList<RecordMetaDataComposableMap>() : null);

    for (Map.Entry<String, RecordReader> entry : recordReaders.entrySet()) {
        RecordReader rr = entry.getValue();

        //Extract this reader's metadata from each example's composable map, keyed by reader name
        List<RecordMetaData> thisRRMeta = new ArrayList<>();
        for (RecordMetaData m : list) {
            RecordMetaDataComposableMap m2 = (RecordMetaDataComposableMap) m;
            thisRRMeta.add(m2.getMeta().get(entry.getKey()));
        }

        List<Record> fromMeta = rr.loadFromMetaData(thisRRMeta);
        List<List<Writable>> writables = new ArrayList<>(list.size());
        for (Record r : fromMeta) {
            writables.add(r.getRecord());
        }

        nextRRVals.put(entry.getKey(), writables);
    }

    for (Map.Entry<String, SequenceRecordReader> entry : sequenceRecordReaders.entrySet()) {
        SequenceRecordReader rr = entry.getValue();

        //Same extraction for the sequence readers
        List<RecordMetaData> thisRRMeta = new ArrayList<>();
        for (RecordMetaData m : list) {
            RecordMetaDataComposableMap m2 = (RecordMetaDataComposableMap) m;
            thisRRMeta.add(m2.getMeta().get(entry.getKey()));
        }

        List<SequenceRecord> fromMeta = rr.loadSequenceFromMetaData(thisRRMeta);
        List<List<List<Writable>>> writables = new ArrayList<>(list.size());
        for (SequenceRecord r : fromMeta) {
            writables.add(r.getSequenceRecord());
        }

        nextSeqRRVals.put(entry.getKey(), writables);
    }

    return nextMultiDataSet(nextRRVals, nextSeqRRVals, nextMetas);
}
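In practice, the metadata round trip looks roughly like the sketch below: enable metadata collection, pull a batch, read its RecordMetaData, and hand that list back to loadFromMetaData to re-fetch exactly those examples. The reader setup (the CSVRecordReader, the file name, the column indices, and the "data" reader name) is an illustrative assumption, not part of the snippet above; imports and exception handling are omitted as in the other snippets here.

    // Minimal sketch (assumed setup): collect metadata while iterating, then reload the same examples
    RecordReader rr = new CSVRecordReader();                        // hypothetical reader
    rr.initialize(new FileSplit(new File("data.csv")));             // hypothetical data file

    RecordReaderMultiDataSetIterator iter = new RecordReaderMultiDataSetIterator.Builder(32)
                    .addReader("data", rr)
                    .addInput("data", 0, 3)                         // columns 0..3 as features
                    .addOutput("data", 4, 4)                        // column 4 as label
                    .build();
    iter.setCollectMetaData(true);    // next() now attaches a RecordMetaDataComposableMap per example

    MultiDataSet mds = iter.next();
    List<RecordMetaData> meta = mds.getExampleMetaData(RecordMetaData.class);

    // Reload exactly the same examples later, e.g. to inspect them individually
    MultiDataSet reloaded = iter.loadFromMetaData(meta);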
Use of org.datavec.api.records.metadata.RecordMetaDataComposableMap in project deeplearning4j by deeplearning4j.
From the class RecordReaderMultiDataSetIterator, the method nextMultiDataSet:
private MultiDataSet nextMultiDataSet(Map<String, List<List<Writable>>> nextRRVals,
                Map<String, List<List<List<Writable>>>> nextSeqRRVals,
                List<RecordMetaDataComposableMap> nextMetas) {
    int minExamples = Integer.MAX_VALUE;
    for (List<List<Writable>> exampleData : nextRRVals.values()) {
        minExamples = Math.min(minExamples, exampleData.size());
    }
    for (List<List<List<Writable>>> exampleData : nextSeqRRVals.values()) {
        minExamples = Math.min(minExamples, exampleData.size());
    }
    if (minExamples == Integer.MAX_VALUE)
        //Should never happen
        throw new RuntimeException("Error occurred during data set generation: no readers?");

    //In order to align data at the end (for each example individually), we need to know the length of the
    // longest time series for each example
    int[] longestSequence = null;
    if (alignmentMode == AlignmentMode.ALIGN_END) {
        longestSequence = new int[minExamples];
        for (Map.Entry<String, List<List<List<Writable>>>> entry : nextSeqRRVals.entrySet()) {
            List<List<List<Writable>>> list = entry.getValue();
            for (int i = 0; i < list.size() && i < minExamples; i++) {
                longestSequence[i] = Math.max(longestSequence[i], list.get(i).size());
            }
        }
    }

    //Second: create the input arrays
    //To do this, we need to know longest time series length, so we can do padding
    int longestTS = -1;
    if (alignmentMode != AlignmentMode.EQUAL_LENGTH) {
        for (Map.Entry<String, List<List<List<Writable>>>> entry : nextSeqRRVals.entrySet()) {
            List<List<List<Writable>>> list = entry.getValue();
            for (List<List<Writable>> c : list) {
                longestTS = Math.max(longestTS, c.size());
            }
        }
    }

    INDArray[] inputArrs = new INDArray[inputs.size()];
    INDArray[] inputArrMasks = new INDArray[inputs.size()];
    boolean inputMasks = false;
    int i = 0;
    for (SubsetDetails d : inputs) {
        if (nextRRVals.containsKey(d.readerName)) {
            //Standard reader
            List<List<Writable>> list = nextRRVals.get(d.readerName);
            inputArrs[i] = convertWritables(list, minExamples, d);
        } else {
            //Sequence reader
            List<List<List<Writable>>> list = nextSeqRRVals.get(d.readerName);
            Pair<INDArray, INDArray> p = convertWritablesSequence(list, minExamples, longestTS, d, longestSequence);
            inputArrs[i] = p.getFirst();
            inputArrMasks[i] = p.getSecond();
            if (inputArrMasks[i] != null)
                inputMasks = true;
        }
        i++;
    }
    if (!inputMasks)
        inputArrMasks = null;

    //Third: create the outputs
    INDArray[] outputArrs = new INDArray[outputs.size()];
    INDArray[] outputArrMasks = new INDArray[outputs.size()];
    boolean outputMasks = false;
    i = 0;
    for (SubsetDetails d : outputs) {
        if (nextRRVals.containsKey(d.readerName)) {
            //Standard reader
            List<List<Writable>> list = nextRRVals.get(d.readerName);
            outputArrs[i] = convertWritables(list, minExamples, d);
        } else {
            //Sequence reader
            List<List<List<Writable>>> list = nextSeqRRVals.get(d.readerName);
            Pair<INDArray, INDArray> p = convertWritablesSequence(list, minExamples, longestTS, d, longestSequence);
            outputArrs[i] = p.getFirst();
            outputArrMasks[i] = p.getSecond();
            if (outputArrMasks[i] != null)
                outputMasks = true;
        }
        i++;
    }
    if (!outputMasks)
        outputArrMasks = null;

    MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(inputArrs, outputArrs, inputArrMasks, outputArrMasks);
    if (collectMetaData) {
        mds.setExampleMetaData(nextMetas);
    }
    if (preProcessor != null)
        preProcessor.preProcess(mds);
    return mds;
}
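The minExamples, longestTS, and longestSequence bookkeeping above is what makes mixed-length sequence batches work: every series is padded to the longest one in the batch, the mask arrays mark which time steps are real, and ALIGN_END shifts each example's values to the end of its padded series. A minimal sketch of driving that path, assuming two hypothetical CSVSequenceRecordReader inputs over numbered files (the reader names and file patterns are illustrative):

    // Minimal sketch (assumed readers/files): pad sequences of different lengths and inspect the masks
    SequenceRecordReader featSeq = new CSVSequenceRecordReader();              // hypothetical
    featSeq.initialize(new NumberedFileInputSplit("features_%d.csv", 0, 9));   // hypothetical files
    SequenceRecordReader labelSeq = new CSVSequenceRecordReader();             // hypothetical
    labelSeq.initialize(new NumberedFileInputSplit("labels_%d.csv", 0, 9));

    RecordReaderMultiDataSetIterator iter = new RecordReaderMultiDataSetIterator.Builder(10)
                    .addSequenceReader("feat", featSeq)
                    .addSequenceReader("lab", labelSeq)
                    .addInput("feat")
                    .addOutput("lab")
                    .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END)
                    .build();

    MultiDataSet mds = iter.next();
    // 1.0 where real time steps exist, 0.0 where padding was added to reach the longest series
    INDArray featureMask = mds.getFeaturesMaskArray(0);
    INDArray labelMask = mds.getLabelsMaskArray(0);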
Use of org.datavec.api.records.metadata.RecordMetaDataComposableMap in project deeplearning4j by deeplearning4j.
From the class RecordReaderMultiDataSetIterator, the method next:
@Override
public MultiDataSet next(int num) {
    if (!hasNext())
        throw new NoSuchElementException("No next elements");

    //First: load the next values from the RR / SeqRRs
    Map<String, List<List<Writable>>> nextRRVals = new HashMap<>();
    Map<String, List<List<List<Writable>>>> nextSeqRRVals = new HashMap<>();
    List<RecordMetaDataComposableMap> nextMetas =
                    (collectMetaData ? new ArrayList<RecordMetaDataComposableMap>() : null);

    for (Map.Entry<String, RecordReader> entry : recordReaders.entrySet()) {
        RecordReader rr = entry.getValue();
        List<List<Writable>> writables = new ArrayList<>(num);
        for (int i = 0; i < num && rr.hasNext(); i++) {
            List<Writable> record;
            if (collectMetaData) {
                Record r = rr.nextRecord();
                record = r.getRecord();
                if (nextMetas.size() <= i) {
                    nextMetas.add(new RecordMetaDataComposableMap(new HashMap<String, RecordMetaData>()));
                }
                RecordMetaDataComposableMap map = nextMetas.get(i);
                map.getMeta().put(entry.getKey(), r.getMetaData());
            } else {
                record = rr.next();
            }
            writables.add(record);
        }

        nextRRVals.put(entry.getKey(), writables);
    }

    for (Map.Entry<String, SequenceRecordReader> entry : sequenceRecordReaders.entrySet()) {
        SequenceRecordReader rr = entry.getValue();
        List<List<List<Writable>>> writables = new ArrayList<>(num);
        for (int i = 0; i < num && rr.hasNext(); i++) {
            List<List<Writable>> sequence;
            if (collectMetaData) {
                SequenceRecord r = rr.nextSequence();
                sequence = r.getSequenceRecord();
                if (nextMetas.size() <= i) {
                    nextMetas.add(new RecordMetaDataComposableMap(new HashMap<String, RecordMetaData>()));
                }
                RecordMetaDataComposableMap map = nextMetas.get(i);
                map.getMeta().put(entry.getKey(), r.getMetaData());
            } else {
                sequence = rr.sequenceRecord();
            }
            writables.add(sequence);
        }

        nextSeqRRVals.put(entry.getKey(), writables);
    }

    return nextMultiDataSet(nextRRVals, nextSeqRRVals, nextMetas);
}
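Note how next(int) builds one RecordMetaDataComposableMap per example, keyed by reader name, and loadFromMetaData later splits it back apart with getMeta().get(readerName). A minimal sketch of that structure, where featureRecord and labelRecord stand in for Record instances produced by two named readers:

    // Minimal sketch (featureRecord/labelRecord are assumed Record instances from two named readers)
    Map<String, RecordMetaData> perReader = new HashMap<>();
    perReader.put("feat", featureRecord.getMetaData());
    perReader.put("lab", labelRecord.getMetaData());
    RecordMetaDataComposableMap composed = new RecordMetaDataComposableMap(perReader);

    // loadFromMetaData recovers each reader's metadata by name:
    RecordMetaData featMeta = composed.getMeta().get("feat");
    RecordMetaData labMeta = composed.getMeta().get("lab");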