Search in sources:

Example 1 with RecordMetaDataComposableMap

Use of org.datavec.api.records.metadata.RecordMetaDataComposableMap in project deeplearning4j by deeplearning4j.

From class RecordReaderMultiDataSetIterator, method loadFromMetaData.

/**
 * Loads the examples identified by the provided RecordMetaData instances into a single MultiDataSet.
 *
 * <p>Every element of {@code list} must be a {@link RecordMetaDataComposableMap} whose map is keyed by reader
 * name, with each value being that reader's own metadata for the example. This is the form {@link #next(int)}
 * builds when metadata collection is enabled. For each record reader and sequence record reader, the
 * per-reader metadata is extracted and passed to that reader's load-from-metadata method.
 *
 * @param list List of RecordMetaData instances to load from. Should have been produced by this
 *             RecordReaderMultiDataSetIterator (with metadata collection enabled)
 * @return MultiDataSet with the specified examples; when metadata collection is enabled, its example metadata
 *         is the requested metadata
 * @throws IOException        If an error occurs during loading of the data
 * @throws ClassCastException If an element of {@code list} is not a RecordMetaDataComposableMap
 */
public MultiDataSet loadFromMetaData(List<RecordMetaData> list) throws IOException {
    // Cast once up front: every entry must be a composable map keyed by reader name
    List<RecordMetaDataComposableMap> composite = new ArrayList<>(list.size());
    for (RecordMetaData m : list) {
        composite.add((RecordMetaDataComposableMap) m);
    }

    // Hand the requested metadata through so the resulting MultiDataSet reports it
    // (nextMultiDataSet only attaches it when collectMetaData is set)
    List<RecordMetaDataComposableMap> nextMetas = null;
    if (collectMetaData) {
        nextMetas = new ArrayList<>(composite);
    }

    //First: load the values from the RRs
    Map<String, List<List<Writable>>> nextRRVals = new HashMap<>();
    for (Map.Entry<String, RecordReader> entry : recordReaders.entrySet()) {
        String readerName = entry.getKey();
        List<Record> fromMeta = entry.getValue().loadFromMetaData(metaForReader(composite, readerName));
        List<List<Writable>> writables = new ArrayList<>(fromMeta.size());
        for (Record r : fromMeta) {
            writables.add(r.getRecord());
        }
        nextRRVals.put(readerName, writables);
    }

    //Then: load the values from the SeqRRs
    Map<String, List<List<List<Writable>>>> nextSeqRRVals = new HashMap<>();
    for (Map.Entry<String, SequenceRecordReader> entry : sequenceRecordReaders.entrySet()) {
        String readerName = entry.getKey();
        List<SequenceRecord> fromMeta =
                entry.getValue().loadSequenceFromMetaData(metaForReader(composite, readerName));
        List<List<List<Writable>>> writables = new ArrayList<>(fromMeta.size());
        for (SequenceRecord r : fromMeta) {
            writables.add(r.getSequenceRecord());
        }
        nextSeqRRVals.put(readerName, writables);
    }

    return nextMultiDataSet(nextRRVals, nextSeqRRVals, nextMetas);
}

/**
 * Extracts, for each requested example, the metadata belonging to the named reader.
 * The entry is null if the composable map has no entry for that reader.
 */
private static List<RecordMetaData> metaForReader(List<RecordMetaDataComposableMap> composite,
                                                  String readerName) {
    List<RecordMetaData> out = new ArrayList<>(composite.size());
    for (RecordMetaDataComposableMap m : composite) {
        out.add(m.getMeta().get(readerName));
    }
    return out;
}
Also used: RecordMetaData (org.datavec.api.records.metadata.RecordMetaData), SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader), RecordReader (org.datavec.api.records.reader.RecordReader), NDArrayWritable (org.datavec.common.data.NDArrayWritable), Writable (org.datavec.api.writable.Writable), SequenceRecord (org.datavec.api.records.SequenceRecord), Record (org.datavec.api.records.Record), RecordMetaDataComposableMap (org.datavec.api.records.metadata.RecordMetaDataComposableMap)

Example 2 with RecordMetaDataComposableMap

Use of org.datavec.api.records.metadata.RecordMetaDataComposableMap in project deeplearning4j by deeplearning4j.

From class RecordReaderMultiDataSetIterator, method nextMultiDataSet.

/**
 * Assembles a MultiDataSet (inputs, outputs and optional masks) from values already read from the record
 * readers and sequence record readers.
 *
 * <p>{@code minExamples}, the smallest number of examples returned by any reader, is passed to the
 * conversion helpers. When metadata collection is enabled, the metadata list is trimmed to the same length.
 * Readers can return different counts when one runs out of data early. In that case the trailing metadata
 * entries would otherwise describe examples that are not in the returned dataset.
 *
 * @param nextRRVals    values from the standard record readers, keyed by reader name
 * @param nextSeqRRVals sequences from the sequence record readers, keyed by reader name
 * @param nextMetas     per-example metadata; only read when metadata collection is enabled
 * @return the assembled MultiDataSet, after the preprocessor (if any) has been applied
 * @throws RuntimeException if neither map contains any reader
 */
private MultiDataSet nextMultiDataSet(Map<String, List<List<Writable>>> nextRRVals,
                                      Map<String, List<List<List<Writable>>>> nextSeqRRVals,
                                      List<RecordMetaDataComposableMap> nextMetas) {
    int minExamples = Integer.MAX_VALUE;
    for (List<List<Writable>> exampleData : nextRRVals.values()) {
        minExamples = Math.min(minExamples, exampleData.size());
    }
    for (List<List<List<Writable>>> exampleData : nextSeqRRVals.values()) {
        minExamples = Math.min(minExamples, exampleData.size());
    }
    if (minExamples == Integer.MAX_VALUE) {
        //Should never happen
        throw new RuntimeException("Error occurred during data set generation: no readers?");
    }

    //In order to align data at the end (for each example individually), we need to know the length of the
    // longest time series for each example
    int[] longestSequence = null;
    if (alignmentMode == AlignmentMode.ALIGN_END) {
        longestSequence = new int[minExamples];
        for (List<List<List<Writable>>> sequences : nextSeqRRVals.values()) {
            for (int i = 0; i < sequences.size() && i < minExamples; i++) {
                longestSequence[i] = Math.max(longestSequence[i], sequences.get(i).size());
            }
        }
    }

    //To create the arrays we need the longest time series length, so we can do padding.
    // NOTE(review): this scans every returned sequence, including examples beyond minExamples; bounding it
    // like longestSequence would avoid surplus padding - confirm convertWritablesSequence tolerates that first.
    int longestTS = -1;
    if (alignmentMode != AlignmentMode.EQUAL_LENGTH) {
        for (List<List<List<Writable>>> sequences : nextSeqRRVals.values()) {
            for (List<List<Writable>> c : sequences) {
                longestTS = Math.max(longestTS, c.size());
            }
        }
    }

    //Second: create the input arrays
    INDArray[] inputArrs = new INDArray[inputs.size()];
    INDArray[] inputArrMasks = new INDArray[inputs.size()];
    if (!convertSubsetArrays(inputs, nextRRVals, nextSeqRRVals, minExamples, longestTS, longestSequence,
                    inputArrs, inputArrMasks)) {
        inputArrMasks = null;
    }

    //Third: create the outputs
    INDArray[] outputArrs = new INDArray[outputs.size()];
    INDArray[] outputArrMasks = new INDArray[outputs.size()];
    if (!convertSubsetArrays(outputs, nextRRVals, nextSeqRRVals, minExamples, longestTS, longestSequence,
                    outputArrs, outputArrMasks)) {
        outputArrMasks = null;
    }

    MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(inputArrs, outputArrs, inputArrMasks, outputArrMasks);
    if (collectMetaData) {
        List<RecordMetaDataComposableMap> metas = nextMetas;
        if (metas != null && metas.size() > minExamples) {
            // Keep metadata aligned with the examples actually present; copy rather than keep a subList view
            metas = new ArrayList<>(metas.subList(0, minExamples));
        }
        mds.setExampleMetaData(metas);
    }
    if (preProcessor != null) {
        preProcessor.preProcess(mds);
    }
    return mds;
}

/**
 * Converts the reader output selected by each SubsetDetails into one array per subset. The i-th subset's
 * array is written to {@code arrs[i]}; for sequence readers its mask is written to {@code masks[i]}.
 *
 * @return true if at least one subset produced a non-null mask
 */
private boolean convertSubsetArrays(Iterable<SubsetDetails> subsets,
                                    Map<String, List<List<Writable>>> nextRRVals,
                                    Map<String, List<List<List<Writable>>>> nextSeqRRVals,
                                    int minExamples, int longestTS, int[] longestSequence,
                                    INDArray[] arrs, INDArray[] masks) {
    boolean anyMask = false;
    int i = 0;
    for (SubsetDetails d : subsets) {
        if (nextRRVals.containsKey(d.readerName)) {
            //Standard reader
            arrs[i] = convertWritables(nextRRVals.get(d.readerName), minExamples, d);
        } else {
            //Sequence reader
            Pair<INDArray, INDArray> p = convertWritablesSequence(nextSeqRRVals.get(d.readerName), minExamples,
                            longestTS, d, longestSequence);
            arrs[i] = p.getFirst();
            masks[i] = p.getSecond();
            if (masks[i] != null) {
                anyMask = true;
            }
        }
        i++;
    }
    return anyMask;
}
Also used: NDArrayWritable (org.datavec.common.data.NDArrayWritable), Writable (org.datavec.api.writable.Writable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet), RecordMetaDataComposableMap (org.datavec.api.records.metadata.RecordMetaDataComposableMap)

Example 3 with RecordMetaDataComposableMap

Use of org.datavec.api.records.metadata.RecordMetaDataComposableMap in project deeplearning4j by deeplearning4j.

From class RecordReaderMultiDataSetIterator, method next.

/**
 * Reads up to {@code num} examples from every record reader and sequence record reader and assembles them
 * into a MultiDataSet.
 *
 * <p>A reader stops early once it reports no more data. When metadata collection is enabled, the metadata of
 * all readers for example {@code i} is gathered into one RecordMetaDataComposableMap, keyed by reader name.
 *
 * @param num maximum number of examples to read from each reader
 * @return the next MultiDataSet
 * @throws NoSuchElementException   if {@link #hasNext()} is false
 * @throws IllegalArgumentException if {@code num} is negative
 */
@Override
public MultiDataSet next(int num) {
    if (!hasNext()) {
        throw new NoSuchElementException("No next elements");
    }
    if (num < 0) {
        throw new IllegalArgumentException("Number of examples must be non-negative, got " + num);
    }

    //First: load the next values from the RR / SeqRRs
    Map<String, List<List<Writable>>> nextRRVals = new HashMap<>();
    Map<String, List<List<List<Writable>>>> nextSeqRRVals = new HashMap<>();
    List<RecordMetaDataComposableMap> nextMetas = (collectMetaData ? new ArrayList<RecordMetaDataComposableMap>() : null);

    for (Map.Entry<String, RecordReader> entry : recordReaders.entrySet()) {
        RecordReader rr = entry.getValue();
        List<List<Writable>> writables = new ArrayList<>(num);
        for (int i = 0; i < num && rr.hasNext(); i++) {
            if (collectMetaData) {
                Record r = rr.nextRecord();
                metaDataSlot(nextMetas, i).getMeta().put(entry.getKey(), r.getMetaData());
                writables.add(r.getRecord());
            } else {
                writables.add(rr.next());
            }
        }
        nextRRVals.put(entry.getKey(), writables);
    }

    for (Map.Entry<String, SequenceRecordReader> entry : sequenceRecordReaders.entrySet()) {
        SequenceRecordReader rr = entry.getValue();
        List<List<List<Writable>>> writables = new ArrayList<>(num);
        for (int i = 0; i < num && rr.hasNext(); i++) {
            if (collectMetaData) {
                SequenceRecord r = rr.nextSequence();
                metaDataSlot(nextMetas, i).getMeta().put(entry.getKey(), r.getMetaData());
                writables.add(r.getSequenceRecord());
            } else {
                writables.add(rr.sequenceRecord());
            }
        }
        nextSeqRRVals.put(entry.getKey(), writables);
    }

    return nextMultiDataSet(nextRRVals, nextSeqRRVals, nextMetas);
}

/**
 * Returns the composable metadata map for example {@code index}. If this is the first reader to produce an
 * example at that index, a new empty map is appended first. Readers fill indices in order, so at most one
 * append is needed.
 */
private static RecordMetaDataComposableMap metaDataSlot(List<RecordMetaDataComposableMap> metas, int index) {
    if (metas.size() <= index) {
        metas.add(new RecordMetaDataComposableMap(new HashMap<String, RecordMetaData>()));
    }
    return metas.get(index);
}
Also used: SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader), RecordReader (org.datavec.api.records.reader.RecordReader), NDArrayWritable (org.datavec.common.data.NDArrayWritable), Writable (org.datavec.api.writable.Writable), SequenceRecord (org.datavec.api.records.SequenceRecord), Record (org.datavec.api.records.Record), RecordMetaDataComposableMap (org.datavec.api.records.metadata.RecordMetaDataComposableMap)

Aggregations

RecordMetaDataComposableMap (org.datavec.api.records.metadata.RecordMetaDataComposableMap): 3; Writable (org.datavec.api.writable.Writable): 3; NDArrayWritable (org.datavec.common.data.NDArrayWritable): 3; Record (org.datavec.api.records.Record): 2; SequenceRecord (org.datavec.api.records.SequenceRecord): 2; RecordReader (org.datavec.api.records.reader.RecordReader): 2; SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader): 2; RecordMetaData (org.datavec.api.records.metadata.RecordMetaData): 1; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1; MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet): 1