
Example 1 with Writable

use of org.datavec.api.writable.Writable in project deeplearning4j by deeplearning4j.

the class RecordReaderDataSetIterator method getDataSet.

private DataSet getDataSet(List<Writable> record) {
    List<Writable> currList;
    if (record instanceof List)
        currList = record;
    else
        currList = new ArrayList<>(record);
    //allow people to specify label index as -1 and infer the last possible label
    if (numPossibleLabels >= 1 && labelIndex < 0) {
        labelIndex = record.size() - 1;
    }
    INDArray label = null;
    INDArray featureVector = null;
    int featureCount = 0;
    int labelCount = 0;
    //no labels
    if (currList.size() == 2 && currList.get(1) instanceof NDArrayWritable && currList.get(0) instanceof NDArrayWritable && currList.get(0) == currList.get(1)) {
        NDArrayWritable writable = (NDArrayWritable) currList.get(0);
        return new DataSet(writable.get(), writable.get());
    }
    if (currList.size() == 2 && currList.get(0) instanceof NDArrayWritable) {
        if (!regression) {
            label = FeatureUtil.toOutcomeVector((int) Double.parseDouble(currList.get(1).toString()), numPossibleLabels);
        } else {
            if (currList.get(1) instanceof NDArrayWritable) {
                label = ((NDArrayWritable) currList.get(1)).get();
            } else {
                label = Nd4j.scalar(currList.get(1).toDouble());
            }
        }
        NDArrayWritable ndArrayWritable = (NDArrayWritable) currList.get(0);
        featureVector = ndArrayWritable.get();
        return new DataSet(featureVector, label);
    }
    for (int j = 0; j < currList.size(); j++) {
        Writable current = currList.get(j);
        //ndarray writable is an insane slowdown here
        if (!(current instanceof NDArrayWritable) && current.toString().isEmpty())
            continue;
        if (regression && j == labelIndex && j == labelIndexTo && current instanceof NDArrayWritable) {
            //Case: NDArrayWritable for the labels
            label = ((NDArrayWritable) current).get();
        } else if (regression && j >= labelIndex && j <= labelIndexTo) {
            //This is the multi-label regression case
            if (label == null)
                label = Nd4j.create(1, (labelIndexTo - labelIndex + 1));
            label.putScalar(labelCount++, current.toDouble());
        } else if (labelIndex >= 0 && j == labelIndex) {
            //single label case (classification, etc)
            if (converter != null)
                try {
                    current = converter.convert(current);
                } catch (WritableConverterException e) {
                    e.printStackTrace();
                }
            if (numPossibleLabels < 1)
                throw new IllegalStateException("Number of possible labels invalid, must be >= 1");
            if (regression) {
                label = Nd4j.scalar(current.toDouble());
            } else {
                int curr = current.toInt();
                if (curr < 0 || curr >= numPossibleLabels) {
                    throw new DL4JInvalidInputException("Invalid classification data: expect label value (at label index column = " + labelIndex + ") to be in range 0 to " + (numPossibleLabels - 1) + " inclusive (0 to numClasses-1, with numClasses=" + numPossibleLabels + "); got label value of " + current);
                }
                label = FeatureUtil.toOutcomeVector(curr, numPossibleLabels);
            }
        } else {
            try {
                double value = current.toDouble();
                if (featureVector == null) {
                    if (regression && labelIndex >= 0) {
                        //Handle the possibly multi-label regression case here:
                        int nLabels = labelIndexTo - labelIndex + 1;
                        featureVector = Nd4j.create(1, currList.size() - nLabels);
                    } else {
                        //Classification case, and also no-labels case
                        featureVector = Nd4j.create(labelIndex >= 0 ? currList.size() - 1 : currList.size());
                    }
                }
                featureVector.putScalar(featureCount++, value);
            } catch (UnsupportedOperationException e) {
                // This isn't a scalar, so check if we got an array already
                if (current instanceof NDArrayWritable) {
                    assert featureVector == null;
                    featureVector = ((NDArrayWritable) current).get();
                } else {
                    throw e;
                }
            }
        }
    }
    return new DataSet(featureVector, labelIndex >= 0 ? label : featureVector);
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) ArrayList(java.util.ArrayList) NDArrayWritable(org.datavec.common.data.NDArrayWritable) Writable(org.datavec.api.writable.Writable) WritableConverterException(org.datavec.api.io.converters.WritableConverterException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) List(java.util.List) DL4JInvalidInputException(org.deeplearning4j.exception.DL4JInvalidInputException)
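
For context, a minimal usage sketch (not part of the project source): getDataSet(...) above is invoked for every record while iterating a RecordReaderDataSetIterator. The file name, column layout and class count below are assumptions for illustration.

import java.io.File;
import java.util.Arrays;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;

public class RecordReaderIteratorSketch {
    public static void main(String[] args) throws Exception {
        // Assumed file: 4 numeric feature columns followed by an integer class label (e.g. iris.csv)
        RecordReader recordReader = new CSVRecordReader();
        recordReader.initialize(new FileSplit(new File("iris.csv")));

        int batchSize = 50;
        int labelIndex = 4;     // label is the 5th column
        int numClasses = 3;     // class indices must lie in [0, numClasses)
        RecordReaderDataSetIterator iterator =
                new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);

        while (iterator.hasNext()) {
            // each record in the batch passes through getDataSet(List<Writable>) shown above
            DataSet ds = iterator.next();
            System.out.println("Features shape: " + Arrays.toString(ds.getFeatures().shape()));
        }
    }
}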

Example 2 with Writable

use of org.datavec.api.writable.Writable in project deeplearning4j by deeplearning4j.

the class RecordReaderDataSetIterator method next.

@Override
public DataSet next(int num) {
    if (useCurrent) {
        useCurrent = false;
        if (preProcessor != null)
            preProcessor.preProcess(last);
        return last;
    }
    List<DataSet> dataSets = new ArrayList<>();
    List<RecordMetaData> meta = (collectMetaData ? new ArrayList<RecordMetaData>() : null);
    for (int i = 0; i < num; i++) {
        if (!hasNext())
            break;
        if (recordReader instanceof SequenceRecordReader) {
            if (sequenceIter == null || !sequenceIter.hasNext()) {
                List<List<Writable>> sequenceRecord = ((SequenceRecordReader) recordReader).sequenceRecord();
                sequenceIter = sequenceRecord.iterator();
            }
            List<Writable> record = sequenceIter.next();
            dataSets.add(getDataSet(record));
        } else {
            if (collectMetaData) {
                Record record = recordReader.nextRecord();
                dataSets.add(getDataSet(record.getRecord()));
                meta.add(record.getMetaData());
            } else {
                List<Writable> record = recordReader.next();
                dataSets.add(getDataSet(record));
            }
        }
    }
    batchNum++;
    if (dataSets.isEmpty())
        return new DataSet();
    DataSet ret = DataSet.merge(dataSets);
    if (collectMetaData) {
        ret.setExampleMetaData(meta);
    }
    last = ret;
    if (preProcessor != null)
        preProcessor.preProcess(ret);
    //Add label name values to dataset
    if (recordReader.getLabels() != null)
        ret.setLabelNames(recordReader.getLabels());
    return ret;
}
Also used : RecordMetaData(org.datavec.api.records.metadata.RecordMetaData) SequenceRecordReader(org.datavec.api.records.reader.SequenceRecordReader) DataSet(org.nd4j.linalg.dataset.DataSet) ArrayList(java.util.ArrayList) NDArrayWritable(org.datavec.common.data.NDArrayWritable) Writable(org.datavec.api.writable.Writable) List(java.util.List) Record(org.datavec.api.records.Record)
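
A hedged sketch of the collectMetaData branch of next(int): after setCollectMetaData(true), each returned DataSet carries the RecordMetaData collected above, which can later be used to re-load exactly the same examples. The reader and column layout are assumed to match the previous sketch.

import java.io.IOException;
import java.util.List;

import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.reader.RecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;

public class MetaDataSketch {
    // recordReader is assumed to be an already-initialized reader (e.g. the CSVRecordReader above)
    public static void reloadFirstBatch(RecordReader recordReader) throws IOException {
        RecordReaderDataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, 10, 4, 3);
        iterator.setCollectMetaData(true);   // next(int) now attaches RecordMetaData to every DataSet

        DataSet batch = iterator.next();     // next() delegates to next(batchSize)
        List<RecordMetaData> meta = batch.getExampleMetaData(RecordMetaData.class);
        System.out.println("First example loaded from: " + meta.get(0).getLocation());

        // Re-read exactly the same examples from the underlying record reader
        DataSet reloaded = iterator.loadFromMetaData(meta);
        System.out.println("Reloaded " + reloaded.numExamples() + " examples");
    }
}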

Example 3 with Writable

use of org.datavec.api.writable.Writable in project deeplearning4j by deeplearning4j.

the class RecordReaderMultiDataSetIterator method loadFromMetaData.

/**
     * Load multiple examples into a MultiDataSet, using the provided RecordMetaData instances.
     *
     * @param list List of RecordMetaData instances to load from. Should have been produced by the record readers
     *             provided to the RecordReaderMultiDataSetIterator constructor
     * @return MultiDataSet with the specified examples
     * @throws IOException If an error occurs during loading of the data
     */
public MultiDataSet loadFromMetaData(List<RecordMetaData> list) throws IOException {
    //First: load the next values from the RR / SeqRRs
    Map<String, List<List<Writable>>> nextRRVals = new HashMap<>();
    Map<String, List<List<List<Writable>>>> nextSeqRRVals = new HashMap<>();
    List<RecordMetaDataComposableMap> nextMetas = (collectMetaData ? new ArrayList<RecordMetaDataComposableMap>() : null);
    for (Map.Entry<String, RecordReader> entry : recordReaders.entrySet()) {
        RecordReader rr = entry.getValue();
        List<RecordMetaData> thisRRMeta = new ArrayList<>();
        for (RecordMetaData m : list) {
            RecordMetaDataComposableMap m2 = (RecordMetaDataComposableMap) m;
            thisRRMeta.add(m2.getMeta().get(entry.getKey()));
        }
        List<Record> fromMeta = rr.loadFromMetaData(thisRRMeta);
        List<List<Writable>> writables = new ArrayList<>(list.size());
        for (Record r : fromMeta) {
            writables.add(r.getRecord());
        }
        nextRRVals.put(entry.getKey(), writables);
    }
    for (Map.Entry<String, SequenceRecordReader> entry : sequenceRecordReaders.entrySet()) {
        SequenceRecordReader rr = entry.getValue();
        List<RecordMetaData> thisRRMeta = new ArrayList<>();
        for (RecordMetaData m : list) {
            RecordMetaDataComposableMap m2 = (RecordMetaDataComposableMap) m;
            thisRRMeta.add(m2.getMeta().get(entry.getKey()));
        }
        List<SequenceRecord> fromMeta = rr.loadSequenceFromMetaData(thisRRMeta);
        List<List<List<Writable>>> writables = new ArrayList<>(list.size());
        for (SequenceRecord r : fromMeta) {
            writables.add(r.getSequenceRecord());
        }
        nextSeqRRVals.put(entry.getKey(), writables);
    }
    return nextMultiDataSet(nextRRVals, nextSeqRRVals, nextMetas);
}
Also used : RecordMetaData(org.datavec.api.records.metadata.RecordMetaData) SequenceRecordReader(org.datavec.api.records.reader.SequenceRecordReader) RecordReader(org.datavec.api.records.reader.RecordReader) NDArrayWritable(org.datavec.common.data.NDArrayWritable) Writable(org.datavec.api.writable.Writable) SequenceRecord(org.datavec.api.records.SequenceRecord) Record(org.datavec.api.records.Record) RecordMetaDataComposableMap(org.datavec.api.records.metadata.RecordMetaDataComposableMap)
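
A sketch of how loadFromMetaData(...) might be reached from user code (not project source). The reader name "csv", column indices and class count are hypothetical, and getExampleMetaData(...) on the returned MultiDataSet is assumed to expose the metadata collected when collectMetaData is enabled.

import java.io.IOException;
import java.util.List;

import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.reader.RecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;

public class MultiIteratorMetaDataSketch {
    public static void reloadBatch(RecordReader recordReader) throws IOException {
        RecordReaderMultiDataSetIterator iterator = new RecordReaderMultiDataSetIterator.Builder(10)
                .addReader("csv", recordReader)
                .addInput("csv", 0, 3)          // columns 0..3 as features
                .addOutputOneHot("csv", 4, 3)   // column 4 as a one-hot label with 3 classes
                .build();
        iterator.setCollectMetaData(true);

        MultiDataSet mds = iterator.next();
        List<RecordMetaData> meta = mds.getExampleMetaData(RecordMetaData.class);

        // loadFromMetaData(...) shown above re-reads exactly those examples from the reader
        MultiDataSet reloaded = iterator.loadFromMetaData(meta);
    }
}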

Example 4 with Writable

use of org.datavec.api.writable.Writable in project deeplearning4j by deeplearning4j.

the class RecordReaderMultiDataSetIterator method nextMultiDataSet.

private MultiDataSet nextMultiDataSet(Map<String, List<List<Writable>>> nextRRVals, Map<String, List<List<List<Writable>>>> nextSeqRRVals, List<RecordMetaDataComposableMap> nextMetas) {
    int minExamples = Integer.MAX_VALUE;
    for (List<List<Writable>> exampleData : nextRRVals.values()) {
        minExamples = Math.min(minExamples, exampleData.size());
    }
    for (List<List<List<Writable>>> exampleData : nextSeqRRVals.values()) {
        minExamples = Math.min(minExamples, exampleData.size());
    }
    if (minExamples == Integer.MAX_VALUE)
        //Should never happen
        throw new RuntimeException("Error occurred during data set generation: no readers?");
    //In order to align data at the end (for each example individually), we need to know the length of the
    // longest time series for each example
    int[] longestSequence = null;
    if (alignmentMode == AlignmentMode.ALIGN_END) {
        longestSequence = new int[minExamples];
        for (Map.Entry<String, List<List<List<Writable>>>> entry : nextSeqRRVals.entrySet()) {
            List<List<List<Writable>>> list = entry.getValue();
            for (int i = 0; i < list.size() && i < minExamples; i++) {
                longestSequence[i] = Math.max(longestSequence[i], list.get(i).size());
            }
        }
    }
    //Second: create the input arrays
    //To do this, we need to know longest time series length, so we can do padding
    int longestTS = -1;
    if (alignmentMode != AlignmentMode.EQUAL_LENGTH) {
        for (Map.Entry<String, List<List<List<Writable>>>> entry : nextSeqRRVals.entrySet()) {
            List<List<List<Writable>>> list = entry.getValue();
            for (List<List<Writable>> c : list) {
                longestTS = Math.max(longestTS, c.size());
            }
        }
    }
    INDArray[] inputArrs = new INDArray[inputs.size()];
    INDArray[] inputArrMasks = new INDArray[inputs.size()];
    boolean inputMasks = false;
    int i = 0;
    for (SubsetDetails d : inputs) {
        if (nextRRVals.containsKey(d.readerName)) {
            //Standard reader
            List<List<Writable>> list = nextRRVals.get(d.readerName);
            inputArrs[i] = convertWritables(list, minExamples, d);
        } else {
            //Sequence reader
            List<List<List<Writable>>> list = nextSeqRRVals.get(d.readerName);
            Pair<INDArray, INDArray> p = convertWritablesSequence(list, minExamples, longestTS, d, longestSequence);
            inputArrs[i] = p.getFirst();
            inputArrMasks[i] = p.getSecond();
            if (inputArrMasks[i] != null)
                inputMasks = true;
        }
        i++;
    }
    if (!inputMasks)
        inputArrMasks = null;
    //Third: create the outputs
    INDArray[] outputArrs = new INDArray[outputs.size()];
    INDArray[] outputArrMasks = new INDArray[outputs.size()];
    boolean outputMasks = false;
    i = 0;
    for (SubsetDetails d : outputs) {
        if (nextRRVals.containsKey(d.readerName)) {
            //Standard reader
            List<List<Writable>> list = nextRRVals.get(d.readerName);
            outputArrs[i] = convertWritables(list, minExamples, d);
        } else {
            //Sequence reader
            List<List<List<Writable>>> list = nextSeqRRVals.get(d.readerName);
            Pair<INDArray, INDArray> p = convertWritablesSequence(list, minExamples, longestTS, d, longestSequence);
            outputArrs[i] = p.getFirst();
            outputArrMasks[i] = p.getSecond();
            if (outputArrMasks[i] != null)
                outputMasks = true;
        }
        i++;
    }
    if (!outputMasks)
        outputArrMasks = null;
    MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(inputArrs, outputArrs, inputArrMasks, outputArrMasks);
    if (collectMetaData) {
        mds.setExampleMetaData(nextMetas);
    }
    if (preProcessor != null)
        preProcessor.preProcess(mds);
    return mds;
}
Also used : NDArrayWritable(org.datavec.common.data.NDArrayWritable) Writable(org.datavec.api.writable.Writable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) RecordMetaDataComposableMap(org.datavec.api.records.metadata.RecordMetaDataComposableMap)
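
The masking and alignment logic in nextMultiDataSet(...) is exercised when sequence readers with different sequence lengths are combined. A minimal sketch, assuming two already-initialized sequence readers and hypothetical reader names and class counts:

import org.datavec.api.records.reader.SequenceRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator.AlignmentMode;
import org.nd4j.linalg.dataset.api.MultiDataSet;

public class AlignmentSketch {
    public static MultiDataSet nextAligned(SequenceRecordReader featureReader, SequenceRecordReader labelReader) {
        RecordReaderMultiDataSetIterator iterator = new RecordReaderMultiDataSetIterator.Builder(32)
                .addSequenceReader("features", featureReader)
                .addSequenceReader("labels", labelReader)
                .addInput("features")
                .addOutputOneHot("labels", 0, 4)                  // column 0 as a label with 4 classes
                .sequenceAlignmentMode(AlignmentMode.ALIGN_END)   // shorter sequences are padded at the start
                .build();

        MultiDataSet mds = iterator.next();   // assembled by nextMultiDataSet(...) above
        // With ALIGN_END and variable-length sequences, mask arrays mark the padded time steps
        System.out.println("Feature masks present: " + (mds.getFeaturesMaskArrays() != null));
        return mds;
    }
}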

Example 5 with Writable

use of org.datavec.api.writable.Writable in project deeplearning4j by deeplearning4j.

the class SequenceRecordReaderDataSetIterator method getLabels.

private INDArray getLabels(List<List<Writable>> labels) {
    //Size of the record?
    //[timeSeriesLength,vectorSize]
    int[] shape = new int[2];
    //time series/sequence length
    shape[0] = labels.size();
    Iterator<List<Writable>> iter = labels.iterator();
    int i = 0;
    INDArray out = null;
    while (iter.hasNext()) {
        List<Writable> step = iter.next();
        if (i == 0) {
            if (regression) {
                for (Writable w : step) {
                    if (w instanceof NDArrayWritable) {
                        shape[1] += ((NDArrayWritable) w).get().length();
                    } else {
                        shape[1]++;
                    }
                }
            } else {
                shape[1] = numPossibleLabels;
            }
            out = Nd4j.create(shape, 'f');
        }
        Iterator<Writable> timeStepIter = step.iterator();
        int f = 0;
        if (regression) {
            //Load all values
            while (timeStepIter.hasNext()) {
                Writable current = timeStepIter.next();
                if (current instanceof NDArrayWritable) {
                    INDArray w = ((NDArrayWritable) current).get();
                    out.put(new INDArrayIndex[] { NDArrayIndex.point(i), NDArrayIndex.interval(f, f + w.length()) }, w);
                    f += w.length();
                } else {
                    out.put(i, f++, current.toDouble());
                }
            }
        } else {
            //Expect a single value (index) -> convert to one-hot vector
            Writable value = timeStepIter.next();
            int idx = value.toInt();
            if (idx < 0 || idx >= numPossibleLabels) {
                throw new DL4JInvalidInputException("Invalid classification data: expect label value to be in range 0 to " + (numPossibleLabels - 1) + " inclusive (0 to numClasses-1, with numClasses=" + numPossibleLabels + "); got label value of " + idx);
            }
            INDArray line = FeatureUtil.toOutcomeVector(idx, numPossibleLabels);
            out.getRow(i).assign(line);
        }
        i++;
    }
    return out;
}
Also used : NDArrayWritable(org.datavec.common.data.NDArrayWritable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Writable(org.datavec.api.writable.Writable) DL4JInvalidInputException(org.deeplearning4j.exception.DL4JInvalidInputException)
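
A hedged end-to-end sketch that reaches getLabels(...) above: a SequenceRecordReaderDataSetIterator built from two CSVSequenceRecordReaders, one for features and one for labels. The file naming pattern, delimiter, and class count are assumptions.

import java.util.Arrays;

import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;

public class SequenceIteratorSketch {
    public static void main(String[] args) throws Exception {
        // Assumed files: features_0.csv .. features_9.csv and labels_0.csv .. labels_9.csv,
        // one time step per line; each label line is a single class index in [0, 6)
        SequenceRecordReader featureReader = new CSVSequenceRecordReader(0, ",");
        featureReader.initialize(new NumberedFileInputSplit("features_%d.csv", 0, 9));
        SequenceRecordReader labelReader = new CSVSequenceRecordReader(0, ",");
        labelReader.initialize(new NumberedFileInputSplit("labels_%d.csv", 0, 9));

        int miniBatchSize = 5;
        int numLabelClasses = 6;
        SequenceRecordReaderDataSetIterator iterator = new SequenceRecordReaderDataSetIterator(
                featureReader, labelReader, miniBatchSize, numLabelClasses, false /* classification */);

        DataSet ds = iterator.next();
        // getLabels(...) above builds a one-hot label array per sequence; the merged batch stacks them
        System.out.println("Label shape: " + Arrays.toString(ds.getLabels().shape()));
    }
}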

Aggregations

Writable (org.datavec.api.writable.Writable): 26 usages
NDArrayWritable (org.datavec.common.data.NDArrayWritable): 18 usages
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 17 usages
DataSet (org.nd4j.linalg.dataset.DataSet): 12 usages
ArrayList (java.util.ArrayList): 8 usages
SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader): 8 usages
List (java.util.List): 7 usages
Test (org.junit.Test): 7 usages
CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader): 6 usages
DoubleWritable (org.datavec.api.writable.DoubleWritable): 6 usages
IntWritable (org.datavec.api.writable.IntWritable): 6 usages
CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader): 5 usages
Record (org.datavec.api.records.Record): 3 usages
RecordMetaDataComposableMap (org.datavec.api.records.metadata.RecordMetaDataComposableMap): 3 usages
SequenceRecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator): 3 usages
DL4JInvalidInputException (org.deeplearning4j.exception.DL4JInvalidInputException): 3 usages
File (java.io.File): 2 usages
Collection (java.util.Collection): 2 usages
WritableConverterException (org.datavec.api.io.converters.WritableConverterException): 2 usages
SequenceRecord (org.datavec.api.records.SequenceRecord): 2 usages