Search in sources:

Example 26 with ClassPathResource

Use of org.nd4j.linalg.io.ClassPathResource in project deeplearning4j by deeplearning4j.

From class TestGraphLoading, method testGraphLoadingWithVertices.

/**
 * Loads a graph from the classpath fixtures test_graph_vertices.txt and test_graph_edges.txt and checks that it
 * forms a 10-vertex undirected ring (0-1, 1-2, ..., 8-9, 9-0).
 * Every vertex must have exactly two outgoing edges: one to its successor and one from its predecessor,
 * with wrap-around at the ends.
 */
@Test
public void testGraphLoadingWithVertices() throws IOException {
    ClassPathResource verticesCPR = new ClassPathResource("test_graph_vertices.txt");
    ClassPathResource edgesCPR = new ClassPathResource("test_graph_edges.txt");
    EdgeLineProcessor<String> edgeLineProcessor = new DelimitedEdgeLineProcessor(",", false, "//");
    VertexLoader<String> vertexLoader = new DelimitedVertexLoader(":", "//");
    Graph<String, String> graph = GraphLoader.loadGraph(verticesCPR.getTempFileFromArchive().getAbsolutePath(),
                    edgesCPR.getTempFileFromArchive().getAbsolutePath(), vertexLoader, edgeLineProcessor, false);
    System.out.println(graph);
    // Number of vertices in the ring the fixture files are expected to describe
    int nVertices = 10;
    for (int i = 0; i < nVertices; i++) {
        List<Edge<String>> edges = graph.getEdgesOut(i);
        assertEquals(2, edges.size());
        //expect for example 0->1 and 9->0
        Edge<String> first = edges.get(0);
        assertRingEdge(i, first, nVertices);
        Edge<String> second = edges.get(1);
        // The two edges must be different: one starts at i, the other ends at i
        assertNotEquals(first.getFrom(), second.getFrom());
        assertRingEdge(i, second, nVertices);
    }
}

/**
 * Asserts that {@code edge} is one of the two ring edges incident to {@code vertex}.
 * That is either {@code vertex -> (vertex + 1) % nVertices} or
 * {@code (vertex - 1 + nVertices) % nVertices -> vertex}.
 *
 * @param vertex    the vertex whose outgoing edges are being checked
 * @param edge      one of the edges returned by {@code getEdgesOut(vertex)}
 * @param nVertices the number of vertices in the ring
 */
private static void assertRingEdge(int vertex, Edge<String> edge, int nVertices) {
    if (edge.getFrom() == vertex) {
        //undirected edge: i -> i+1 (for the last vertex this wraps around, e.g. 9 -> 0)
        assertEquals((vertex + 1) % nVertices, edge.getTo());
    } else {
        //undirected edge: i-1 -> i (for vertex 0 this wraps around, e.g. 9 -> 0)
        assertEquals((vertex + nVertices - 1) % nVertices, edge.getFrom());
        assertEquals(vertex, edge.getTo());
    }
}
Also used : DelimitedEdgeLineProcessor(org.deeplearning4j.graph.data.impl.DelimitedEdgeLineProcessor) Edge(org.deeplearning4j.graph.api.Edge) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) DelimitedVertexLoader(org.deeplearning4j.graph.data.impl.DelimitedVertexLoader) Test(org.junit.Test)

Example 27 with ClassPathResource

Use of org.nd4j.linalg.io.ClassPathResource in project deeplearning4j by deeplearning4j.

From class TestDeepWalk, method testDeepWalk13Vertices.

/**
 * Loads the undirected edge-list graph graph13.txt (13 vertices) and trains DeepWalk on it.
 * Training runs for 200 epochs with a linearly decaying learning rate.
 * Afterwards it prints the similarity of vertex 0 to every vertex, followed by each vertex's vector.
 * This test only prints its results; it makes no assertions.
 */
@Test
public void testDeepWalk13Vertices() throws IOException {
    int nVertices = 13;
    ClassPathResource cpr = new ClassPathResource("graph13.txt");
    // Pass nVertices rather than repeating the literal, so the vertex count is defined in one place
    Graph<String, String> graph = GraphLoader.loadUndirectedGraphEdgeListFile(
                    cpr.getTempFileFromArchive().getAbsolutePath(), nVertices, ",");
    System.out.println(graph);
    Nd4j.getRandom().setSeed(12345);
    int nEpochs = 200;
    double initialLearningRate = 0.03;
    //Set up network
    DeepWalk<String, String> deepWalk = new DeepWalk.Builder<String, String>().vectorSize(50).windowSize(4)
                    .seed(12345).build();
    //Run learning. The learning rate decays linearly from initialLearningRate (first epoch)
    //down to initialLearningRate / nEpochs (last epoch).
    for (int i = 0; i < nEpochs; i++) {
        deepWalk.setLearningRate(initialLearningRate / nEpochs * (nEpochs - i));
        // NOTE(review): the second argument is presumably the random-walk length — confirm against DeepWalk.fit
        deepWalk.fit(graph, 10);
    }
    //Calculate similarity(0,i)
    for (int i = 0; i < nVertices; i++) {
        System.out.println(deepWalk.similarity(0, i));
    }
    for (int i = 0; i < nVertices; i++) {
        System.out.println(deepWalk.getVertexVector(i));
    }
}
Also used : ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Example 28 with ClassPathResource

Use of org.nd4j.linalg.io.ClassPathResource in project deeplearning4j by deeplearning4j.

From class TestSparkComputationGraph, method testBasic.

/**
 * Smoke test for SparkComputationGraph: trains a small ComputationGraph on Iris twice.
 * The network is a 2-unit dense layer feeding a 3-class MCXENT output layer.
 * The first pass fits an RDD of MultiDataSet read from iris.txt; the second fits an RDD of DataSet
 * from IrisDataSetIterator. The test passes if neither fit throws.
 */
@Test
public void testBasic() throws Exception {
    // iris.txt: columns 0..3 become the input, column 4 is one-hot encoded into 3 classes
    RecordReader csvReader = new CSVRecordReader(0, ",");
    csvReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    MultiDataSetIterator mdsIterator = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addReader("iris", csvReader)
                    .addInput("iris", 0, 3)
                    .addOutputOneHot("iris", 4, 3)
                    .build();
    List<MultiDataSet> multiDataSets = new ArrayList<>(150);
    while (mdsIterator.hasNext()) {
        multiDataSets.add(mdsIterator.next());
    }

    ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .learningRate(0.1)
                    .graphBuilder()
                    .addInputs("in")
                    .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in")
                    .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).build(),
                                    "dense")
                    .setOutputs("out")
                    .pretrain(false)
                    .backprop(true)
                    .build();
    ComputationGraph net = new ComputationGraph(graphConf);
    net.init();

    TrainingMaster trainingMaster = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
    SparkComputationGraph sparkNet = new SparkComputationGraph(sc, net, trainingMaster);
    IterationListener scoreListener = new ScoreIterationListener(1);
    sparkNet.setListeners(Collections.singleton(scoreListener));

    // First pass: MultiDataSet RDD
    sparkNet.fitMultiDataSet(sc.parallelize(multiDataSets));

    //Try: fitting using DataSet
    DataSetIterator irisIterator = new IrisDataSetIterator(1, 150);
    List<DataSet> dataSets = new ArrayList<>();
    while (irisIterator.hasNext()) {
        dataSets.add(irisIterator.next());
    }
    sparkNet.fit(sc.parallelize(dataSets));
}
Also used : IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) RecordReader(org.datavec.api.records.reader.RecordReader) CSVRecordReader(org.datavec.api.records.reader.impl.csv.CSVRecordReader) FileSplit(org.datavec.api.split.FileSplit) TrainingMaster(org.deeplearning4j.spark.api.TrainingMaster) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) RecordReaderMultiDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) IterationListener(org.deeplearning4j.optimize.api.IterationListener) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)

Example 29 with ClassPathResource

Use of org.nd4j.linalg.io.ClassPathResource in project deeplearning4j by deeplearning4j.

From class TestSparkMultiLayerParameterAveraging, method testFromSvmLightBackprop.

/**
 * Loads Iris in svmLight format as an RDD of LabeledPoint (features converted to dense vectors).
 * It trains a 4-100-3 MultiLayerNetwork on that RDD with Spark parameter averaging.
 * It then prints evaluation stats of the trained network on the full Iris dataset.
 * This test only prints its results; it makes no assertions.
 */
@Test
public void testFromSvmLightBackprop() throws Exception {
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(),
                    new ClassPathResource("svmLight/iris_svmLight_0.txt").getTempFileFromArchive().getAbsolutePath())
                    .toJavaRDD().map(new Function<LabeledPoint, LabeledPoint>() {

                        @Override
                        public LabeledPoint call(LabeledPoint v1) throws Exception {
                            // Re-wrap the feature vector as a dense vector
                            return new LabeledPoint(v1.label(), Vectors.dense(v1.features().toArray()));
                        }
                    });
    // ENFORCE_NUMERICAL_STABILITY is a static field on Nd4j.
    // Restore its previous value afterwards so this test does not change the setting for other tests in the JVM.
    boolean previousNumericalStability = Nd4j.ENFORCE_NUMERICAL_STABILITY;
    Nd4j.ENFORCE_NUMERICAL_STABILITY = true;
    try {
        DataSet d = new IrisDataSetIterator(150, 150).next();
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123)
                        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(10).list()
                        .layer(0, new DenseLayer.Builder().nIn(4).nOut(100).weightInit(WeightInit.XAVIER)
                                        .activation(Activation.RELU).build())
                        .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
                                        LossFunctions.LossFunction.MCXENT).nIn(100).nOut(3)
                                                        .activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER)
                                                        .build())
                        .backprop(true).build();
        // Training goes through SparkDl4jMultiLayer, which is built from conf; the network it returns is what gets evaluated
        SparkDl4jMultiLayer master = new SparkDl4jMultiLayer(sc, conf,
                        new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0));
        MultiLayerNetwork network2 = master.fitLabeledPoint(data);
        Evaluation evaluation = new Evaluation();
        evaluation.eval(d.getLabels(), network2.output(d.getFeatureMatrix()));
        System.out.println(evaluation.stats());
    } finally {
        Nd4j.ENFORCE_NUMERICAL_STABILITY = previousNumericalStability;
    }
}
Also used : Evaluation(org.deeplearning4j.eval.Evaluation) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) MultiDataSet(org.nd4j.linalg.dataset.MultiDataSet) DataSet(org.nd4j.linalg.dataset.DataSet) LabeledPoint(org.apache.spark.mllib.regression.LabeledPoint) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) SparkDl4jMultiLayer(org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)

Example 30 with ClassPathResource

Use of org.nd4j.linalg.io.ClassPathResource in project deeplearning4j by deeplearning4j.

From class RecordReaderMultiDataSetIteratorTest, method testSplittingCSVSequence.

/**
 * Checks that RecordReaderMultiDataSetIterator, reading the numbered CSV sequence files, matches
 * SequenceRecordReaderDataSetIterator over the same files.
 * It must produce the same features and labels, with the feature columns split into two inputs:
 * columns 0-1 and column 2.
 */
@Test
public void testSplittingCSVSequence() throws Exception {
    //need to manually extract: the readers below only get a numbered path pattern, so indices 1 and 2
    //must also exist on disk (presumably extracted alongside index 0 — confirm in ClassPathResource)
    for (int i = 0; i < 3; i++) {
        new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive();
        new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive();
    }
    // Only the "_0" file index is turned into "%d". Any other '0' in the temp directory path must stay intact,
    // which replaceAll("0", "%d") on the whole absolute path would not guarantee.
    String featuresPath = toNumberedPathPattern(
                    new ClassPathResource("csvsequence_0.txt").getTempFileFromArchive().getAbsolutePath());
    String labelsPath = toNumberedPathPattern(
                    new ClassPathResource("csvsequencelabels_0.txt").getTempFileFromArchive().getAbsolutePath());
    SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
    featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);
    SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
    featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
    labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
    MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
                    .addSequenceReader("seq1", featureReader2)
                    .addSequenceReader("seq2", labelReader2)
                    .addInput("seq1", 0, 1)
                    .addInput("seq1", 2, 2)
                    .addOutputOneHot("seq2", 0, 4)
                    .build();
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        INDArray fds = ds.getFeatureMatrix();
        INDArray lds = ds.getLabels();
        MultiDataSet mds = srrmdsi.next();
        assertEquals(2, mds.getFeatures().length);
        assertEquals(1, mds.getLabels().length);
        assertNull(mds.getFeaturesMaskArrays());
        assertNull(mds.getLabelsMaskArrays());
        INDArray[] fmds = mds.getFeatures();
        INDArray[] lmds = mds.getLabels();
        assertNotNull(fmds);
        assertNotNull(lmds);
        for (int i = 0; i < fmds.length; i++) {
            assertNotNull(fmds[i]);
        }
        for (int i = 0; i < lmds.length; i++) {
            assertNotNull(lmds[i]);
        }
        // Expected inputs: feature columns 0-1 and column 2 of the single-reader features
        INDArray expIn1 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 1, true), NDArrayIndex.all());
        INDArray expIn2 = fds.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 2, true), NDArrayIndex.all());
        assertEquals(expIn1, fmds[0]);
        assertEquals(expIn2, fmds[1]);
        assertEquals(lds, lmds[0]);
    }
    assertFalse(srrmdsi.hasNext());
}

/**
 * Turns the path of an index-0 numbered file (e.g. ".../csvsequence_0.txt") into a format pattern
 * (".../csvsequence_%d.txt").
 * Only the last "_0." occurrence is replaced, so the file index is substituted and the rest of the path is not.
 * Any other '%' in the path is escaped as "%%"; this assumes NumberedFileInputSplit expands the pattern with
 * String.format — TODO confirm.
 *
 * @param indexZeroPath absolute path of the file with index 0
 * @return the path pattern with "%d" in place of the index
 * @throws IllegalArgumentException if the path contains no "_0." index
 */
private static String toNumberedPathPattern(String indexZeroPath) {
    int idx = indexZeroPath.lastIndexOf("_0.");
    if (idx < 0) {
        throw new IllegalArgumentException("Expected a path containing \"_0.\": " + indexZeroPath);
    }
    String prefix = indexZeroPath.substring(0, idx).replace("%", "%%");
    // idx + 2 skips "_0" and keeps the "." that starts the extension
    String suffix = indexZeroPath.substring(idx + 2).replace("%", "%%");
    return prefix + "_%d" + suffix;
}
Also used : CSVSequenceRecordReader(org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader) SequenceRecordReader(org.datavec.api.records.reader.SequenceRecordReader) DataSet(org.nd4j.linalg.dataset.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) NumberedFileInputSplit(org.datavec.api.split.NumberedFileInputSplit) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Test(org.junit.Test)

Aggregations

ClassPathResource (org.nd4j.linalg.io.ClassPathResource): 64; Test (org.junit.Test): 63; SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader): 23; CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader): 23; DataSet (org.nd4j.linalg.dataset.DataSet): 23; MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 20; FileSplit (org.datavec.api.split.FileSplit): 18; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 18; File (java.io.File): 17; CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader): 14; CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader): 13; RecordReader (org.datavec.api.records.reader.RecordReader): 12; NumberedFileInputSplit (org.datavec.api.split.NumberedFileInputSplit): 12; MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 12; MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet): 11; RecordMetaData (org.datavec.api.records.metadata.RecordMetaData): 8; MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator): 8; ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader): 7; LossMCXENT (org.nd4j.linalg.lossfunctions.impl.LossMCXENT): 7; DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 6