Use of org.deeplearning4j.models.sequencevectors.graph.primitives.Graph in project deeplearning4j by deeplearning4j.
The class SequenceVectorsTest, method buildGraph.
/**
 * Builds the BlogCatalog graph from two CSV files on the local disk and
 * sanity-checks it before handing it back.
 *
 * <p>Vertices are read from {@code /ext/Temp/BlogCatalog/nodes.csv} (one
 * {@link Blogger} per row, id taken from the first column) and edges from
 * {@code /ext/Temp/BlogCatalog/edges.csv} (first column = source id, second
 * column = target id). Every edge is added with a weight of {@code 1.0}.
 *
 * <p>After loading, the neighbour counts of vertices 0, 1 and 3 are logged and
 * asserted to be 119, 9 and 6 respectively, so the method fails if the data
 * files differ from the expected data set.
 *
 * @return the populated graph
 * @throws IOException if a CSV file cannot be opened, read or closed
 * @throws InterruptedException propagated from {@code CSVRecordReader.initialize}
 */
private static Graph<Blogger, Double> buildGraph() throws IOException, InterruptedException {
    // Load vertices: one Blogger per row of nodes.csv.
    List<Blogger> bloggers = new ArrayList<>();
    CSVRecordReader nodeReader = new CSVRecordReader(0, ",");
    try {
        nodeReader.initialize(new FileSplit(new File("/ext/Temp/BlogCatalog/nodes.csv")));
        while (nodeReader.hasNext()) {
            List<Writable> lines = new ArrayList<>(nodeReader.next());
            bloggers.add(new Blogger(lines.get(0).toInt()));
        }
    } finally {
        // Release the file handle even when reading fails part-way through.
        nodeReader.close();
    }

    // NOTE(review): the boolean presumably enables multiple edges between a vertex pair - confirm against Graph's constructor.
    Graph<Blogger, Double> graph = new Graph<>(bloggers, true);

    // load edges
    CSVRecordReader edgeReader = new CSVRecordReader(0, ",");
    try {
        edgeReader.initialize(new FileSplit(new File("/ext/Temp/BlogCatalog/edges.csv")));
        while (edgeReader.hasNext()) {
            List<Writable> lines = new ArrayList<>(edgeReader.next());
            int from = lines.get(0).toInt();
            int to = lines.get(1).toInt();
            // Ids in the file are shifted down by one to obtain vertex indexes
            // (presumably the file is 1-based while the graph is 0-based).
            // NOTE(review): the trailing 'false' presumably means "undirected" - confirm.
            graph.addEdge(from - 1, to - 1, 1.0, false);
        }
    } finally {
        // The original never closed this reader; close it on every path.
        edgeReader.close();
    }

    logger.info("Connected on 0: [" + graph.getConnectedVertices(0).size() + "]");
    logger.info("Connected on 1: [" + graph.getConnectedVertices(1).size() + "]");
    logger.info("Connected on 3: [" + graph.getConnectedVertices(3).size() + "]");

    // Guard against a wrong or truncated data set.
    assertEquals(119, graph.getConnectedVertices(0).size());
    assertEquals(9, graph.getConnectedVertices(1).size());
    assertEquals(6, graph.getConnectedVertices(3).size());
    return graph;
}
Aggregations