
Example 21 with Writable

Use of org.datavec.api.writable.Writable in project deeplearning4j by deeplearning4j.

The class NDArrayRecordToNDArray, method convert:

@Override
public INDArray convert(Collection<Collection<Writable>> records) {
    INDArray[] concat = new INDArray[records.size()];
    int count = 0;
    for (Collection<Writable> record : records) {
        NDArrayWritable writable = (NDArrayWritable) record.iterator().next();
        concat[count++] = writable.get();
    }
    return Nd4j.concat(0, concat);
}
Also used: NDArrayWritable(org.datavec.common.data.NDArrayWritable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Writable(org.datavec.api.writable.Writable)
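
A minimal usage sketch (hypothetical data; java.util imports elided). Each record must contain a single NDArrayWritable holding a row vector, and convert stacks them along dimension 0:

Collection<Collection<Writable>> records = new ArrayList<>();
records.add(Collections.<Writable>singletonList(new NDArrayWritable(Nd4j.create(new double[] { 1, 2, 3 }))));
records.add(Collections.<Writable>singletonList(new NDArrayWritable(Nd4j.create(new double[] { 4, 5, 6 }))));
//Two [1, 3] row vectors concatenated along dimension 0 -> shape [2, 3]
INDArray stacked = new NDArrayRecordToNDArray().convert(records);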

Example 22 with Writable

Use of org.datavec.api.writable.Writable in project deeplearning4j by deeplearning4j.

The class DataVecDataSetFunction, method call:

@Override
public DataSet call(List<Writable> currList) throws Exception {
    //allow people to specify label index as -1 and infer the last possible label
    int labelIndex = this.labelIndex;
    if (numPossibleLabels >= 1 && labelIndex < 0) {
        labelIndex = currList.size() - 1;
    }
    INDArray label = null;
    INDArray featureVector = null;
    int featureCount = 0;
    int labelCount = 0;
    //No-labels case: both entries are the same NDArrayWritable object (note the reference-equality check), so features double as labels
    if (currList.size() == 2 && currList.get(1) instanceof NDArrayWritable && currList.get(0) instanceof NDArrayWritable && currList.get(0) == currList.get(1)) {
        NDArrayWritable writable = (NDArrayWritable) currList.get(0);
        DataSet ds = new DataSet(writable.get(), writable.get());
        if (preProcessor != null)
            preProcessor.preProcess(ds);
        return ds;
    }
    if (currList.size() == 2 && currList.get(0) instanceof NDArrayWritable) {
        if (!regression)
            label = FeatureUtil.toOutcomeVector((int) Double.parseDouble(currList.get(1).toString()), numPossibleLabels);
        else
            label = Nd4j.scalar(Double.parseDouble(currList.get(1).toString()));
        NDArrayWritable ndArrayWritable = (NDArrayWritable) currList.get(0);
        featureVector = ndArrayWritable.get();
        DataSet ds = new DataSet(featureVector, label);
        if (preProcessor != null)
            preProcessor.preProcess(ds);
        return ds;
    }
    for (int j = 0; j < currList.size(); j++) {
        Writable current = currList.get(j);
        //NDArrayWritable's toString() is an insane slowdown here, so skip the empty-string check for arrays
        if (!(current instanceof NDArrayWritable) && current.toString().isEmpty())
            continue;
        if (labelIndex >= 0 && j >= labelIndex && j <= labelIndexTo) {
            //single label case (classification, single label regression etc)
            if (converter != null) {
                try {
                    current = converter.convert(current);
                } catch (WritableConverterException e) {
                    e.printStackTrace();
                }
            }
            if (regression) {
                //single and multi-label regression
                if (label == null) {
                    label = Nd4j.zeros(labelIndexTo - labelIndex + 1);
                }
                label.putScalar(0, labelCount++, current.toDouble());
            } else {
                if (numPossibleLabels < 1)
                    throw new IllegalStateException("Number of possible labels invalid, must be >= 1 for classification");
                int curr = current.toInt();
                if (curr >= numPossibleLabels)
                    throw new IllegalStateException("Invalid index: got index " + curr + " but numPossibleLabels is " + numPossibleLabels + " (must be 0 <= idx < numPossibleLabels");
                label = FeatureUtil.toOutcomeVector(curr, numPossibleLabels);
            }
        } else {
            try {
                double value = current.toDouble();
                if (featureVector == null) {
                    if (regression && labelIndex >= 0) {
                        //Handle the possibly multi-label regression case here:
                        int nLabels = labelIndexTo - labelIndex + 1;
                        featureVector = Nd4j.create(1, currList.size() - nLabels);
                    } else {
                        //Classification case, and also no-labels case
                        featureVector = Nd4j.create(labelIndex >= 0 ? currList.size() - 1 : currList.size());
                    }
                }
                featureVector.putScalar(featureCount++, value);
            } catch (UnsupportedOperationException e) {
                // This isn't a scalar, so check if we got an array already
                if (current instanceof NDArrayWritable) {
                    assert featureVector == null;
                    featureVector = ((NDArrayWritable) current).get();
                } else {
                    throw e;
                }
            }
        }
    }
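    //No label index given: reuse the features as the labels (featureVector appears twice in the DataSet)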
    DataSet ds = new DataSet(featureVector, (labelIndex >= 0 ? label : featureVector));
    if (preProcessor != null)
        preProcessor.preProcess(ds);
    return ds;
}
Also used: NDArrayWritable(org.datavec.common.data.NDArrayWritable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) DataSet(org.nd4j.linalg.dataset.DataSet) Writable(org.datavec.api.writable.Writable) WritableConverterException(org.datavec.api.io.converters.WritableConverterException)
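
A hypothetical Spark usage sketch. The sc handle, file path, and column layout are illustrative, and the (labelIndex, numPossibleLabels, regression) constructor plus DataVec's StringToWritablesFunction are assumptions here:

//Parse each CSV line into a List<Writable>, then convert each record to a DataSet
JavaRDD<String> lines = sc.textFile("iris.csv");
JavaRDD<List<Writable>> parsed = lines.map(new StringToWritablesFunction(new CSVRecordReader()));
//Label in column 4, 3 classes, classification (regression == false)
JavaRDD<DataSet> dataSets = parsed.map(new DataVecDataSetFunction(4, 3, false));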

Example 23 with Writable

Use of org.datavec.api.writable.Writable in project deeplearning4j by deeplearning4j.

The class DataVecSequencePairDataSetFunction, method call:

@Override
public DataSet call(Tuple2<List<List<Writable>>, List<List<Writable>>> input) throws Exception {
    List<List<Writable>> featuresSeq = input._1();
    List<List<Writable>> labelsSeq = input._2();
    int featuresLength = featuresSeq.size();
    int labelsLength = labelsSeq.size();
    Iterator<List<Writable>> fIter = featuresSeq.iterator();
    Iterator<List<Writable>> lIter = labelsSeq.iterator();
    INDArray inputArr = null;
    INDArray outputArr = null;
    int[] idx = new int[3];
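    //idx addresses the 3d arrays as [miniBatchIndex, vectorIndex, timeStepIndex]; idx[0] stays 0 (one example per DataSet)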
    int i = 0;
    while (fIter.hasNext()) {
        List<Writable> step = fIter.next();
        if (i == 0) {
            int[] inShape = new int[] { 1, step.size(), featuresLength };
            inputArr = Nd4j.create(inShape);
        }
        Iterator<Writable> timeStepIter = step.iterator();
        int f = 0;
        idx[1] = 0;
        while (timeStepIter.hasNext()) {
            Writable current = timeStepIter.next();
            if (converter != null)
                current = converter.convert(current);
            try {
                inputArr.putScalar(idx, current.toDouble());
            } catch (UnsupportedOperationException e) {
                // This isn't a scalar, so check if we got an array already
                if (current instanceof NDArrayWritable) {
                    inputArr.get(NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[2])).putRow(0, ((NDArrayWritable) current).get());
                } else {
                    throw e;
                }
            }
            idx[1] = ++f;
        }
        idx[2] = ++i;
    }
    idx = new int[3];
    i = 0;
    while (lIter.hasNext()) {
        List<Writable> step = lIter.next();
        if (i == 0) {
            int[] outShape = new int[] { 1, (regression ? step.size() : numPossibleLabels), labelsLength };
            outputArr = Nd4j.create(outShape);
        }
        Iterator<Writable> timeStepIter = step.iterator();
        int f = 0;
        idx[1] = 0;
        if (regression) {
            //Load all values without modification
            while (timeStepIter.hasNext()) {
                Writable current = timeStepIter.next();
                if (converter != null)
                    current = converter.convert(current);
                outputArr.putScalar(idx, current.toDouble());
                idx[1] = ++f;
            }
        } else {
            //Expect a single value (index) -> convert to one-hot vector
            Writable value = timeStepIter.next();
            int labelClassIdx = value.toInt();
            INDArray line = FeatureUtil.toOutcomeVector(labelClassIdx, numPossibleLabels);
            //1d from [1,nOut,timeSeriesLength] -> tensor i along dimension 1 is at time i
            outputArr.tensorAlongDimension(i, 1).assign(line);
        }
        idx[2] = ++i;
    }
    DataSet ds;
    if (alignmentMode == AlignmentMode.EQUAL_LENGTH || featuresLength == labelsLength) {
        ds = new DataSet(inputArr, outputArr);
    } else if (alignmentMode == AlignmentMode.ALIGN_END) {
        if (featuresLength > labelsLength) {
            //Input longer, pad output
            INDArray newOutput = Nd4j.create(1, outputArr.size(1), featuresLength);
            newOutput.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.interval(featuresLength - labelsLength, featuresLength)).assign(outputArr);
            //Need an output mask array, but not an input mask array
            INDArray outputMask = Nd4j.create(1, featuresLength);
            for (int j = featuresLength - labelsLength; j < featuresLength; j++) outputMask.putScalar(j, 1.0);
            ds = new DataSet(inputArr, newOutput, Nd4j.ones(outputMask.shape()), outputMask);
        } else {
            //Output longer, pad input
            INDArray newInput = Nd4j.create(1, inputArr.size(1), labelsLength);
            newInput.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.interval(labelsLength - featuresLength, labelsLength)).assign(inputArr);
            //Need an input mask array, but not an output mask array
            INDArray inputMask = Nd4j.create(1, labelsLength);
            for (int j = labelsLength - featuresLength; j < labelsLength; j++) inputMask.putScalar(j, 1.0);
            ds = new DataSet(newInput, outputArr, inputMask, Nd4j.ones(inputMask.shape()));
        }
    } else if (alignmentMode == AlignmentMode.ALIGN_START) {
        if (featuresLength > labelsLength) {
            //Input longer, pad output
            INDArray newOutput = Nd4j.create(1, outputArr.size(1), featuresLength);
            newOutput.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.interval(0, labelsLength)).assign(outputArr);
            //Need an output mask array, but not an input mask array
            INDArray outputMask = Nd4j.create(1, featuresLength);
            for (int j = 0; j < labelsLength; j++) outputMask.putScalar(j, 1.0);
            ds = new DataSet(inputArr, newOutput, Nd4j.ones(outputMask.shape()), outputMask);
        } else {
            //Output longer, pad input
            INDArray newInput = Nd4j.create(1, inputArr.size(1), labelsLength);
            newInput.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.interval(0, featuresLength)).assign(inputArr);
            //Need an input mask array, but not an output mask array
            INDArray inputMask = Nd4j.create(1, labelsLength);
            for (int j = 0; j < featuresLength; j++) inputMask.putScalar(j, 1.0);
            ds = new DataSet(newInput, outputArr, inputMask, Nd4j.ones(inputMask.shape()));
        }
    } else {
        throw new UnsupportedOperationException("Invalid alignment mode: " + alignmentMode);
    }
    if (preProcessor != null)
        preProcessor.preProcess(ds);
    return ds;
}
Also used: NDArrayWritable(org.datavec.common.data.NDArrayWritable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) DataSet(org.nd4j.linalg.dataset.DataSet) Writable(org.datavec.api.writable.Writable) List(java.util.List)
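
A hypothetical usage sketch: featureSeqs and labelSeqs are assumed JavaRDDs of sequences (e.g., produced by SequenceRecordReaderFunction), and a (numPossibleLabels, regression, alignmentMode) constructor is assumed:

//Pair each feature sequence with its label sequence, then build padded/masked DataSets.
//Note: zip requires both RDDs to have the same number of elements per partition.
JavaPairRDD<List<List<Writable>>, List<List<Writable>>> pairs = featureSeqs.zip(labelSeqs);
JavaRDD<DataSet> data = pairs.map(new DataVecSequencePairDataSetFunction(numClasses, false,
        DataVecSequencePairDataSetFunction.AlignmentMode.ALIGN_END));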

Example 24 with Writable

Use of org.datavec.api.writable.Writable in project deeplearning4j by deeplearning4j.

The class RecordReaderFunction, method call:

@Override
public DataSet call(String v1) throws Exception {
    recordReader.initialize(new StringSplit(v1));
    List<DataSet> dataSets = new ArrayList<>();
    List<Writable> record = recordReader.next();
    List<Writable> currList;
    if (record instanceof List)
        currList = (List<Writable>) record;
    else
        currList = new ArrayList<>(record);
    INDArray label = null;
    INDArray featureVector = Nd4j.create(labelIndex >= 0 ? currList.size() - 1 : currList.size());
    int count = 0;
    for (int j = 0; j < currList.size(); j++) {
        if (labelIndex >= 0 && j == labelIndex) {
            if (numPossibleLabels < 1)
                throw new IllegalStateException("Number of possible labels invalid, must be >= 1");
            Writable current = currList.get(j);
            if (converter != null)
                current = converter.convert(current);
            label = FeatureUtil.toOutcomeVector(Double.valueOf(current.toString()).intValue(), numPossibleLabels);
        } else {
            Writable current = currList.get(j);
            featureVector.putScalar(count++, Double.valueOf(current.toString()));
        }
    }
    dataSets.add(new DataSet(featureVector, labelIndex >= 0 ? label : featureVector));
    List<INDArray> inputs = new ArrayList<>();
    List<INDArray> labels = new ArrayList<>();
    for (DataSet data : dataSets) {
        inputs.add(data.getFeatureMatrix());
        labels.add(data.getLabels());
    }
    DataSet ret = new DataSet(Nd4j.vstack(inputs.toArray(new INDArray[0])), Nd4j.vstack(labels.toArray(new INDArray[0])));
    return ret;
}
Also used: StringSplit(org.datavec.api.split.StringSplit) INDArray(org.nd4j.linalg.api.ndarray.INDArray) DataSet(org.nd4j.linalg.dataset.DataSet) ArrayList(java.util.ArrayList) Writable(org.datavec.api.writable.Writable) List(java.util.List)
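
A hypothetical usage sketch (path and label layout illustrative; a (recordReader, labelIndex, numPossibleLabels) constructor is assumed). Unlike DataVecDataSetFunction above, this function re-initializes the record reader on each input String:

//Each String element is parsed as one CSV record; label in column 4, 3 classes
JavaRDD<String> lines = sc.textFile("iris.csv");
JavaRDD<DataSet> data = lines.map(new RecordReaderFunction(new CSVRecordReader(), 4, 3));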

Example 25 with Writable

Use of org.datavec.api.writable.Writable in project deeplearning4j by deeplearning4j.

The class TestDataVecDataSetFunctions, method testDataVecSequenceDataSetFunction:

@Test
public void testDataVecSequenceDataSetFunction() throws Exception {
    JavaSparkContext sc = getContext();
    //Test Spark record reader functionality vs. local
    File f = new File("src/test/resources/csvsequence/csvsequence_0.txt");
    String path = f.getPath();
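    //Strip the 17-character file name "csvsequence_0.txt" to leave the containing folder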
    String folder = path.substring(0, path.length() - 17);
    path = folder + "*";
    JavaPairRDD<String, PortableDataStream> origData = sc.binaryFiles(path);
    //3 CSV sequences
    assertEquals(3, origData.count());
    SequenceRecordReader seqRR = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReaderFunction rrf = new SequenceRecordReaderFunction(seqRR);
    JavaRDD<List<List<Writable>>> rdd = origData.map(rrf);
    JavaRDD<DataSet> data = rdd.map(new DataVecSequenceDataSetFunction(2, -1, true, null, null));
    List<DataSet> collected = data.collect();
    //Load normally (i.e., not via Spark), and check that we get the same results (order notwithstanding)
    InputSplit is = new FileSplit(new File(folder), new String[] { "txt" }, true);
    SequenceRecordReader seqRR2 = new CSVSequenceRecordReader(1, ",");
    seqRR2.initialize(is);
    SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(seqRR2, 1, -1, 2, true);
    List<DataSet> listLocal = new ArrayList<>(3);
    while (iter.hasNext()) {
        listLocal.add(iter.next());
    }
    //Compare:
    assertEquals(3, collected.size());
    assertEquals(3, listLocal.size());
    //Check that results are the same (order notwithstanding)
    boolean[] found = new boolean[3];
    for (int i = 0; i < 3; i++) {
        int foundIndex = -1;
        DataSet ds = collected.get(i);
        for (int j = 0; j < 3; j++) {
            if (ds.equals(listLocal.get(j))) {
                if (foundIndex != -1)
                    //Already found this value -> suggests this Spark value equals two or more of the local values (shouldn't happen)
                    fail();
                foundIndex = j;
                if (found[foundIndex])
                    //One of the other Spark values was equal to this one -> suggests duplicates in the Spark list
                    fail();
                //mark this one as seen before
                found[foundIndex] = true;
            }
        }
    }
    int count = 0;
    for (boolean b : found) if (b)
        count++;
    //Expect all 3 and exactly 3 pairwise matches between spark and local versions
    assertEquals(3, count);
}
Also used : CSVSequenceRecordReader(org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader) SequenceRecordReader(org.datavec.api.records.reader.SequenceRecordReader) SequenceRecordReaderFunction(org.datavec.spark.functions.SequenceRecordReaderFunction) DataSet(org.nd4j.linalg.dataset.DataSet) SequenceRecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator) ArrayList(java.util.ArrayList) PortableDataStream(org.apache.spark.input.PortableDataStream) Writable(org.datavec.api.writable.Writable) FileSplit(org.datavec.api.split.FileSplit) CSVSequenceRecordReader(org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader) ArrayList(java.util.ArrayList) List(java.util.List) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) File(java.io.File) NumberedFileInputSplit(org.datavec.api.split.NumberedFileInputSplit) InputSplit(org.datavec.api.split.InputSplit) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)

Aggregations

Writable (org.datavec.api.writable.Writable): 26
NDArrayWritable (org.datavec.common.data.NDArrayWritable): 18
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 17
DataSet (org.nd4j.linalg.dataset.DataSet): 12
ArrayList (java.util.ArrayList): 8
SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader): 8
List (java.util.List): 7
Test (org.junit.Test): 7
CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader): 6
DoubleWritable (org.datavec.api.writable.DoubleWritable): 6
IntWritable (org.datavec.api.writable.IntWritable): 6
CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader): 5
Record (org.datavec.api.records.Record): 3
RecordMetaDataComposableMap (org.datavec.api.records.metadata.RecordMetaDataComposableMap): 3
SequenceRecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator): 3
DL4JInvalidInputException (org.deeplearning4j.exception.DL4JInvalidInputException): 3
File (java.io.File): 2
Collection (java.util.Collection): 2
WritableConverterException (org.datavec.api.io.converters.WritableConverterException): 2
SequenceRecord (org.datavec.api.records.SequenceRecord): 2