Search in sources:

Example 46 with ClassPathResource

Use of org.nd4j.linalg.io.ClassPathResource in project nd4j by deeplearning4j.

From class TensorFlowImportTest, method testCrash_119_simpleif_0.

/**
 * Imports the frozen {@code simpleif_0} TensorFlow graph, binds a 2x2 float array to
 * {@code input_0} and the scalar {@code 11f} to {@code input_1}, and writes the resulting graph
 * with {@code asFlatFile} to {@code ../../../libnd4j/tests_cpu/resources/simpleif_0.fb}.
 *
 * <p>Marked {@link Ignore}: it writes to a path relative to the working directory, so it is
 * presumably run by hand to regenerate that fixture.
 */
@Test
@Ignore
public void testCrash_119_simpleif_0() throws Exception {
    // NOTE(review): presumably forces ND4J backend initialisation before the import; confirm.
    Nd4j.create(1);
    // try-with-resources closes the classpath stream even if the import throws. If the importer
    // also closes it, a second close() on a Closeable has no effect.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(graphStream);
        assertNotNull(tg);
        val input0 = Nd4j.create(new float[] {1, 2, 3, 4}, new int[] {2, 2});
        val input1 = Nd4j.trueScalar(11f);
        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));
        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simpleif_0.fb"));
    }
}
Also used: lombok.val (lombok.val), File (java.io.File), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), Ignore (org.junit.Ignore), Test (org.junit.Test)

Example 47 with ClassPathResource

Use of org.nd4j.linalg.io.ClassPathResource in project nd4j by deeplearning4j.

From class TensorFlowImportTest, method testIntermediateTensorArraySimple1.

/**
 * Imports the {@code tf_graphs/tensor_array.pb.txt} graph, sets {@code input_matrix} to a 3x2
 * array of ones, and serializes the graph to FlatBuffers. It then checks that the serialized
 * graph has 36 variables and more than one node.
 */
@Test
@Ignore
public void testIntermediateTensorArraySimple1() throws Exception {
    // NOTE(review): presumably forces ND4J backend initialisation before the import; confirm.
    Nd4j.create(1);
    // try-with-resources closes the classpath stream even if the import throws.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/tensor_array.pb.txt").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(graphStream);
        // Assert before the first dereference, so a failed import is reported as an assertion
        // failure instead of a NullPointerException from updateVariable.
        assertNotNull(tg);
        tg.updateVariable("input_matrix", Nd4j.ones(3, 2));
        val fb = tg.asFlatBuffers();
        assertNotNull(fb);
        val graph = FlatGraph.getRootAsFlatGraph(fb);
        assertEquals(36, graph.variablesLength());
        assertTrue(graph.nodesLength() > 1);
        // Disabled checks carried over from the original test:
        // assertEquals("strided_slice", graph.nodes(0).name());
        // assertEquals("TensorArray", graph.nodes(1).name());
        // assertEquals(4, graph.nodes(0).inputPairedLength());
        // tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/tensor_array.fb"));
    }
}
Also used: lombok.val (lombok.val), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), Ignore (org.junit.Ignore), Test (org.junit.Test)

Example 48 with ClassPathResource

Use of org.nd4j.linalg.io.ClassPathResource in project nd4j by deeplearning4j.

From class TestNDArrayCreation, method testCreateNpy3.

/**
 * Loads the {@code rank3.npy} fixture and checks that it has 8 elements and rank 3. It then
 * passes the array's data-buffer address to the native ops' {@code pointerForAddress} and checks
 * that the returned pointer reports the same address.
 */
@Test
public void testCreateNpy3() throws Exception {
    INDArray loaded = Nd4j.createFromNpyFile(new ClassPathResource("rank3.npy").getFile());
    assertEquals(8, loaded.length());
    assertEquals(3, loaded.rank());

    // Round-trip the raw buffer address through the native layer and compare.
    long bufferAddress = loaded.data().address();
    Pointer mapped = NativeOpsHolder.getInstance()
            .getDeviceNativeOps()
            .pointerForAddress(bufferAddress);
    assertEquals(bufferAddress, mapped.address());
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), FloatPointer (org.bytedeco.javacpp.FloatPointer), Pointer (org.bytedeco.javacpp.Pointer), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), Test (org.junit.Test), BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)

Example 49 with ClassPathResource

Use of org.nd4j.linalg.io.ClassPathResource in project deeplearning4j by deeplearning4j.

From class RecordReaderMultiDataSetIteratorTest, method testSplittingCSVSequence.

/**
 * Compares a {@link RecordReaderMultiDataSetIterator} against a
 * {@link SequenceRecordReaderDataSetIterator} built over the same numbered CSV sequence files.
 *
 * <p>The multi-iterator splits the feature columns into two inputs: columns 0-1 and column 2. It
 * one-hot encodes column 0 of the label reader over 4 classes. For every step, input 0 must
 * equal feature columns 0-1, input 1 must equal column 2, and the labels must match. Both
 * iterators must run out together.
 */
@Test
public void testSplittingCSVSequence() throws Exception {
    // Need to manually extract: NumberedFileInputSplit below resolves every file from a single
    // path pattern, so all numbered resources are extracted before the pattern is used.
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive();
    }
    // Turn ".../csvsequence_0.txt" into ".../csvsequence_%d.txt". Only the trailing index is
    // rewritten. Replacing every '0' would also rewrite zeros in the temp-directory part of the
    // path, giving a pattern that matches no files.
    String featuresPath = new ClassPathResource("csvsequence_0.txt").getTempFileFromArchive()
            .getAbsolutePath().replaceAll("_0\\.txt$", "_%d.txt");
    String labelsPath = new ClassPathResource("csvsequencelabels_0.txt").getTempFileFromArchive()
            .getAbsolutePath().replaceAll("_0\\.txt$", "_%d.txt");

    // Reference iterator. Constructor arguments presumably mean miniBatch=1,
    // numPossibleLabels=4, regression=false; confirm against the datavec API.
    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    SequenceRecordReaderDataSetIterator iter =
            new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);

    // Iterator under test, reading the same files through independent readers.
    SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
    featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
            .addSequenceReader("seq1", featureReader2)
            .addSequenceReader("seq2", labelReader2)
            .addInput("seq1", 0, 1)
            .addInput("seq1", 2, 2)
            .addOutputOneHot("seq2", 0, 4)
            .build();

    while (iter.hasNext()) {
        DataSet ds = iter.next();
        INDArray fds = ds.getFeatureMatrix();
        INDArray lds = ds.getLabels();
        MultiDataSet mds = srrmdsi.next();
        assertEquals(2, mds.getFeatures().length);
        assertEquals(1, mds.getLabels().length);
        assertNull(mds.getFeaturesMaskArrays());
        assertNull(mds.getLabelsMaskArrays());
        INDArray[] fmds = mds.getFeatures();
        INDArray[] lmds = mds.getLabels();
        assertNotNull(fmds);
        assertNotNull(lmds);
        for (int i = 0; i < fmds.length; i++) {
            assertNotNull(fmds[i]);
        }
        for (int i = 0; i < lmds.length; i++) {
            assertNotNull(lmds[i]);
        }
        // Slice the reference features along dimension 1 (interval is inclusive) to build the
        // expected per-input arrays.
        INDArray expIn1 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 1, true), NDArrayIndex.all());
        INDArray expIn2 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 2, true), NDArrayIndex.all());
        assertEquals(expIn1, fmds[0]);
        assertEquals(expIn2, fmds[1]);
        assertEquals(lds, lmds[0]);
    }
    assertFalse(srrmdsi.hasNext());
}
Also used: CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader), SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader), DataSet (org.nd4j.linalg.dataset.DataSet), MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), NumberedFileInputSplit (org.datavec.api.split.NumberedFileInputSplit), MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator), INDArray (org.nd4j.linalg.api.ndarray.INDArray), Test (org.junit.Test)

Example 50 with ClassPathResource

Use of org.nd4j.linalg.io.ClassPathResource in project deeplearning4j by deeplearning4j.

From class RecordReaderMultiDataSetIteratorTest, method testSplittingCSVSequenceMeta.

/**
 * Builds a {@link RecordReaderMultiDataSetIterator} over three numbered CSV sequence files, with
 * feature columns split into two inputs and the labels one-hot encoded over 4 classes.
 *
 * <p>With metadata collection enabled, each produced {@link MultiDataSet} is rebuilt through
 * {@code loadFromMetaData} and must equal the original. Exactly 3 examples must be produced.
 */
@Test
public void testSplittingCSVSequenceMeta() throws Exception {
    // Need to manually extract: NumberedFileInputSplit below resolves every file from a single
    // path pattern, so all numbered resources are extracted before the pattern is used.
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive();
    }
    // Turn ".../csvsequence_0.txt" into ".../csvsequence_%d.txt". Only the trailing index is
    // rewritten. Replacing every '0' would also rewrite zeros in the temp-directory part of the
    // path, giving a pattern that matches no files.
    String featuresPath = new ClassPathResource("csvsequence_0.txt").getTempFileFromArchive()
            .getAbsolutePath().replaceAll("_0\\.txt$", "_%d.txt");
    String labelsPath = new ClassPathResource("csvsequencelabels_0.txt").getTempFileFromArchive()
            .getAbsolutePath().replaceAll("_0\\.txt$", "_%d.txt");

    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    RecordReaderMultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
            .addSequenceReader("seq1", featureReader)
            .addSequenceReader("seq2", labelReader)
            .addInput("seq1", 0, 1)
            .addInput("seq1", 2, 2)
            .addOutputOneHot("seq2", 0, 4)
            .build();
    // Record per-example metadata so each MultiDataSet can be reloaded via loadFromMetaData below.
    srrmdsi.setCollectMetaData(true);

    int count = 0;
    while (srrmdsi.hasNext()) {
        MultiDataSet mds = srrmdsi.next();
        MultiDataSet fromMeta = srrmdsi.loadFromMetaData(mds.getExampleMetaData(RecordMetaData.class));
        assertEquals(mds, fromMeta);
        count++;
    }
    assertEquals(3, count);
}
Also used: RecordMetaData (org.datavec.api.records.metadata.RecordMetaData), CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader), SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader), MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), NumberedFileInputSplit (org.datavec.api.split.NumberedFileInputSplit), Test (org.junit.Test)

Aggregations

ClassPathResource (org.nd4j.linalg.io.ClassPathResource): 112, Test (org.junit.Test): 100, lombok.val (lombok.val): 31, INDArray (org.nd4j.linalg.api.ndarray.INDArray): 26, SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader): 23, CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader): 23, DataSet (org.nd4j.linalg.dataset.DataSet): 23, File (java.io.File): 22, MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 20, FileSplit (org.datavec.api.split.FileSplit): 18, CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader): 14, Ignore (org.junit.Ignore): 14, CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader): 13, RecordReader (org.datavec.api.records.reader.RecordReader): 12, NumberedFileInputSplit (org.datavec.api.split.NumberedFileInputSplit): 12, MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 12, MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet): 11, MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator): 8, RecordMetaData (org.datavec.api.records.metadata.RecordMetaData): 7, ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader): 7