Search in sources:

Example 1 with DoubleWritable

Use of org.datavec.api.writable.DoubleWritable in project deeplearning4j by deeplearning4j.

From the class RecordReaderDataSetiteratorTest, method testRecordReaderDataSetIteratorNDArrayWritableLabels.

/**
 * Checks that {@link RecordReaderDataSetIterator} in regression mode accepts an
 * {@code NDArrayWritable} label column and stacks it into a label matrix.
 *
 * <p>Scenario 1: each record is two {@code DoubleWritable} feature columns followed by an
 * {@code NDArrayWritable} label (column 2).
 * <p>Scenario 2: each record is an {@code NDArrayWritable} of features (column 0) followed
 * by an {@code NDArrayWritable} of labels (column 1).
 *
 * <p>Both scenarios must yield the same feature and label matrices.
 */
@Test
public void testRecordReaderDataSetIteratorNDArrayWritableLabels() {
    Collection<Collection<Writable>> data = new ArrayList<>();
    data.add(Arrays.<Writable>asList(new DoubleWritable(0), new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 1.1, 2.1, 3.1 }))));
    data.add(Arrays.<Writable>asList(new DoubleWritable(2), new DoubleWritable(3), new NDArrayWritable(Nd4j.create(new double[] { 4.1, 5.1, 6.1 }))));
    data.add(Arrays.<Writable>asList(new DoubleWritable(4), new DoubleWritable(5), new NDArrayWritable(Nd4j.create(new double[] { 7.1, 8.1, 9.1 }))));
    RecordReader rr = new CollectionRecordReader(data);
    int batchSize = 3;
    // Scenario 1: the label NDArrayWritable is the third column (index 2).
    int labelIndexFrom = 2;
    int labelIndexTo = 2;
    boolean regression = true;
    DataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, batchSize, labelIndexFrom, labelIndexTo, regression);
    DataSet ds = rrdsi.next();
    INDArray expFeatures = Nd4j.create(new double[][] { { 0, 1 }, { 2, 3 }, { 4, 5 } });
    INDArray expLabels = Nd4j.create(new double[][] { { 1.1, 2.1, 3.1 }, { 4.1, 5.1, 6.1 }, { 7.1, 8.1, 9.1 } });
    assertEquals(expFeatures, ds.getFeatures());
    assertEquals(expLabels, ds.getLabels());

    //ALSO: test if we have NDArrayWritables for BOTH the features and the labels
    data = new ArrayList<>();
    data.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.create(new double[] { 0, 1 })), new NDArrayWritable(Nd4j.create(new double[] { 1.1, 2.1, 3.1 }))));
    data.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.create(new double[] { 2, 3 })), new NDArrayWritable(Nd4j.create(new double[] { 4.1, 5.1, 6.1 }))));
    data.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.create(new double[] { 4, 5 })), new NDArrayWritable(Nd4j.create(new double[] { 7.1, 8.1, 9.1 }))));
    // These records have only two columns (features at 0, labels at 1). Reusing index 2
    // from scenario 1 would point past the end of every record.
    labelIndexFrom = 1;
    labelIndexTo = 1;
    rr = new CollectionRecordReader(data);
    rrdsi = new RecordReaderDataSetIterator(rr, batchSize, labelIndexFrom, labelIndexTo, regression);
    ds = rrdsi.next();
    assertEquals(expFeatures, ds.getFeatures());
    assertEquals(expLabels, ds.getLabels());
}
Also used: DataSet(org.nd4j.linalg.dataset.DataSet) RecordReader(org.datavec.api.records.reader.RecordReader) CollectionRecordReader(org.datavec.api.records.reader.impl.collection.CollectionRecordReader) CSVSequenceRecordReader(org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) SequenceRecordReader(org.datavec.api.records.reader.SequenceRecordReader) CollectionSequenceRecordReader(org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader) DoubleWritable(org.datavec.api.writable.DoubleWritable) NDArrayWritable(org.datavec.common.data.NDArrayWritable) INDArray(org.nd4j.linalg.api.ndarray.INDArray) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) Test(org.junit.Test)

Example 2 with DoubleWritable

Use of org.datavec.api.writable.DoubleWritable in project deeplearning4j by deeplearning4j.

From the class TestRecordReaders, method testClassIndexOutsideOfRangeRRMDSI_MultipleReaders.

/**
 * Checks that {@link SequenceRecordReaderDataSetIterator}, built from separate feature and
 * label sequence readers, throws {@link DL4JException} when a label sequence contains an
 * out-of-range class index.
 *
 * <p>The second label sequence contains the value 2. The final constructor argument (2) is
 * presumably {@code numPossibleLabels}, making 2 out of range — confirm against the
 * iterator's constructor documentation.
 */
@Test
public void testClassIndexOutsideOfRangeRRMDSI_MultipleReaders() {
    // Feature sequences: two sequences of two time steps, each a single 0.0 value.
    Collection<Collection<Collection<Writable>>> c1 = new ArrayList<>();
    Collection<Collection<Writable>> seq1 = new ArrayList<>();
    seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    c1.add(seq1);
    Collection<Collection<Writable>> seq2 = new ArrayList<>();
    seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
    c1.add(seq2);
    // Label sequences: the second one contains class index 2, which should be rejected.
    Collection<Collection<Collection<Writable>>> c2 = new ArrayList<>();
    Collection<Collection<Writable>> seq1a = new ArrayList<>();
    seq1a.add(Arrays.<Writable>asList(new IntWritable(0)));
    seq1a.add(Arrays.<Writable>asList(new IntWritable(1)));
    c2.add(seq1a);
    Collection<Collection<Writable>> seq2a = new ArrayList<>();
    seq2a.add(Arrays.<Writable>asList(new IntWritable(0)));
    seq2a.add(Arrays.<Writable>asList(new IntWritable(2)));
    c2.add(seq2a);
    CollectionSequenceRecordReader csrr = new CollectionSequenceRecordReader(c1);
    CollectionSequenceRecordReader csrrLabels = new CollectionSequenceRecordReader(c2);
    DataSetIterator dsi = new SequenceRecordReaderDataSetIterator(csrr, csrrLabels, 2, 2);
    try {
        dsi.next();
        fail("Expected exception");
    } catch (DL4JException e) {
        System.out.println("testClassIndexOutsideOfRangeRRMDSI_MultipleReaders(): " + e.getMessage());
    } catch (Exception e) {
        // Any other exception type fails the test. Chain it as the cause so the report
        // carries the real stack trace instead of a message-less fail().
        throw new AssertionError("Expected DL4JException but got " + e.getClass().getName(), e);
    }
}
Also used: SequenceRecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator) DataSet(org.nd4j.linalg.dataset.api.DataSet) ArrayList(java.util.ArrayList) CollectionSequenceRecordReader(org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader) IntWritable(org.datavec.api.writable.IntWritable) Writable(org.datavec.api.writable.Writable) DoubleWritable(org.datavec.api.writable.DoubleWritable) DL4JException(org.deeplearning4j.exception.DL4JException) Collection(java.util.Collection) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) Test(org.junit.Test)

Example 3 with DoubleWritable

Use of org.datavec.api.writable.DoubleWritable in project deeplearning4j by deeplearning4j.

From the class TestRecordReaders, method testClassIndexOutsideOfRangeRRDSI.

/**
 * Checks that {@link RecordReaderDataSetIterator} throws {@link DL4JException} when a
 * record's class-index column (column 1) holds an out-of-range value.
 *
 * <p>The second record has class index 2. The constructor arguments {@code (2, 1, 2)} are
 * presumably {@code (batchSize, labelIndex, numPossibleLabels)}, making 2 out of range —
 * confirm against the iterator's constructor documentation.
 */
@Test
public void testClassIndexOutsideOfRangeRRDSI() {
    Collection<Collection<Writable>> c = new ArrayList<>();
    c.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(0)));
    c.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(2)));
    CollectionRecordReader crr = new CollectionRecordReader(c);
    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(crr, 2, 1, 2);
    try {
        iter.next();
        fail("Expected exception");
    } catch (DL4JException e) {
        // Label matches this method's name, consistent with the sibling tests.
        System.out.println("testClassIndexOutsideOfRangeRRDSI(): " + e.getMessage());
    } catch (Exception e) {
        // Any other exception type fails the test. Chain it as the cause so the report
        // carries the real stack trace instead of a message-less fail().
        throw new AssertionError("Expected DL4JException but got " + e.getClass().getName(), e);
    }
}
Also used: DataSet(org.nd4j.linalg.dataset.api.DataSet) ArrayList(java.util.ArrayList) Collection(java.util.Collection) CollectionRecordReader(org.datavec.api.records.reader.impl.collection.CollectionRecordReader) DoubleWritable(org.datavec.api.writable.DoubleWritable) SequenceRecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) DL4JException(org.deeplearning4j.exception.DL4JException) IntWritable(org.datavec.api.writable.IntWritable) Test(org.junit.Test)

Example 4 with DoubleWritable

Use of org.datavec.api.writable.DoubleWritable in project deeplearning4j by deeplearning4j.

From the class TestRecordReaders, method testClassIndexOutsideOfRangeRRMDSI.

/**
 * Checks that {@link SequenceRecordReaderDataSetIterator}, built from a single sequence
 * reader whose time steps hold a feature value and a class index, throws
 * {@link DL4JException} when a class index is out of range.
 *
 * <p>The second sequence contains class index 2 in column 1. The constructor arguments
 * {@code (2, 2, 1)} are presumably {@code (miniBatchSize, numPossibleLabels, labelIndex)},
 * making 2 out of range — confirm against the iterator's constructor documentation.
 */
@Test
public void testClassIndexOutsideOfRangeRRMDSI() {
    Collection<Collection<Collection<Writable>>> c = new ArrayList<>();
    Collection<Collection<Writable>> seq1 = new ArrayList<>();
    seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(0)));
    seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(1)));
    c.add(seq1);
    Collection<Collection<Writable>> seq2 = new ArrayList<>();
    seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(0)));
    seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(2)));
    c.add(seq2);
    CollectionSequenceRecordReader csrr = new CollectionSequenceRecordReader(c);
    DataSetIterator dsi = new SequenceRecordReaderDataSetIterator(csrr, 2, 2, 1);
    try {
        dsi.next();
        fail("Expected exception");
    } catch (DL4JException e) {
        System.out.println("testClassIndexOutsideOfRangeRRMDSI(): " + e.getMessage());
    } catch (Exception e) {
        // Any other exception type fails the test. Chain it as the cause so the report
        // carries the real stack trace instead of a message-less fail().
        throw new AssertionError("Expected DL4JException but got " + e.getClass().getName(), e);
    }
}
Also used: SequenceRecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator) DataSet(org.nd4j.linalg.dataset.api.DataSet) ArrayList(java.util.ArrayList) CollectionSequenceRecordReader(org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader) IntWritable(org.datavec.api.writable.IntWritable) Writable(org.datavec.api.writable.Writable) DoubleWritable(org.datavec.api.writable.DoubleWritable) DL4JException(org.deeplearning4j.exception.DL4JException) Collection(java.util.Collection) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) RecordReaderDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator) Test(org.junit.Test)

Aggregations

DoubleWritable (org.datavec.api.writable.DoubleWritable)4 Test (org.junit.Test)4 ArrayList (java.util.ArrayList)3 Collection (java.util.Collection)3 CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader)3 IntWritable (org.datavec.api.writable.IntWritable)3 RecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator)3 SequenceRecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator)3 DL4JException (org.deeplearning4j.exception.DL4JException)3 DataSet (org.nd4j.linalg.dataset.api.DataSet)3 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)3 CollectionRecordReader (org.datavec.api.records.reader.impl.collection.CollectionRecordReader)2 Writable (org.datavec.api.writable.Writable)2 RecordReader (org.datavec.api.records.reader.RecordReader)1 SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader)1 CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader)1 CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader)1 NDArrayWritable (org.datavec.common.data.NDArrayWritable)1 INDArray (org.nd4j.linalg.api.ndarray.INDArray)1 DataSet (org.nd4j.linalg.dataset.DataSet)1