Search in sources:

Example 1 with Graph

Use of org.deeplearning4j.models.sequencevectors.graph.primitives.Graph in the project deeplearning4j by deeplearning4j.

From the class SequenceVectorsTest, method buildGraph.

/**
 * Builds the BlogCatalog graph used by this test from two CSV files on local disk.
 *
 * <p>{@code /ext/Temp/BlogCatalog/nodes.csv} is read first. The first column of every row is
 * parsed as an int and wrapped in a {@link Blogger}, one per row, in file order.
 * {@code /ext/Temp/BlogCatalog/edges.csv} is read next. Its first two columns are treated as
 * 1-based vertex ids, and an edge of weight {@code 1.0} is added between {@code from - 1} and
 * {@code to - 1}. The neighbour counts of vertices 0, 1 and 3 are then logged and asserted
 * against the values expected for this dataset.
 *
 * <p>Both record readers are closed even if reading fails part-way through.
 *
 * @return the populated graph
 * @throws IOException if either CSV file cannot be opened or read
 * @throws InterruptedException as declared by the record-reader API (presumably from
 *     {@code initialize}; confirm against the DataVec version in use)
 */
private static Graph<Blogger, Double> buildGraph() throws IOException, InterruptedException {
    // load vertices
    List<Blogger> bloggers = new ArrayList<>();
    CSVRecordReader nodeReader = new CSVRecordReader(0, ",");
    try {
        nodeReader.initialize(new FileSplit(new File("/ext/Temp/BlogCatalog/nodes.csv")));
        while (nodeReader.hasNext()) {
            // Copy into a List: older DataVec versions return a Collection from next().
            List<Writable> columns = new ArrayList<>(nodeReader.next());
            bloggers.add(new Blogger(columns.get(0).toInt()));
        }
    } finally {
        nodeReader.close();
    }

    Graph<Blogger, Double> graph = new Graph<>(bloggers, true);

    // load edges
    CSVRecordReader edgeReader = new CSVRecordReader(0, ",");
    try {
        edgeReader.initialize(new FileSplit(new File("/ext/Temp/BlogCatalog/edges.csv")));
        while (edgeReader.hasNext()) {
            List<Writable> columns = new ArrayList<>(edgeReader.next());
            int from = columns.get(0).toInt();
            int to = columns.get(1).toInt();
            // Edge file ids are 1-based; graph vertex indices are 0-based.
            // NOTE(review): the trailing `false` presumably means "undirected" — confirm.
            graph.addEdge(from - 1, to - 1, 1.0, false);
        }
    } finally {
        // Previously this reader was never closed.
        edgeReader.close();
    }

    logger.info("Connected on 0: [" + graph.getConnectedVertices(0).size() + "]");
    logger.info("Connected on 1: [" + graph.getConnectedVertices(1).size() + "]");
    logger.info("Connected on 3: [" + graph.getConnectedVertices(3).size() + "]");
    assertEquals(119, graph.getConnectedVertices(0).size());
    assertEquals(9, graph.getConnectedVertices(1).size());
    assertEquals(6, graph.getConnectedVertices(3).size());
    return graph;
}
Also used: Graph (org.deeplearning4j.models.sequencevectors.graph.primitives.Graph), CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader), ArrayList (java.util.ArrayList), Writable (org.datavec.api.writable.Writable), FileSplit (org.datavec.api.split.FileSplit), File (java.io.File)

Aggregations

File (java.io.File): 1; ArrayList (java.util.ArrayList): 1; CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader): 1; FileSplit (org.datavec.api.split.FileSplit): 1; Writable (org.datavec.api.writable.Writable): 1; Graph (org.deeplearning4j.models.sequencevectors.graph.primitives.Graph): 1